diff --git a/CHANGELOG.md b/CHANGELOG.md
index 804d0eda..35b1a09f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+- Update `Client`: it now takes a `backend_type` argument instead of `debug=True`
+  and the `DEBUG_SPAWNER` environment variable to select the spawner (#210)
+
 ## [0.33.0] - 2022-09-19
 
 ### Removed
diff --git a/Makefile b/Makefile
index 7485762d..849c44fb 100644
--- a/Makefile
+++ b/Makefile
@@ -17,21 +17,21 @@ test-remote-sdk: pyclean
 	pytest tests -rs -v --durations=0 -m "not workflows" -n $(PARALLELISM)
 
 test-remote-workflows: pyclean
 	pytest tests -v --durations=0 -m "workflows"
 
 test-minimal: pyclean
 	pytest tests -rs -v --durations=0 -m "not slow and not workflows" -n $(PARALLELISM)
 
 test-local: test-subprocess test-docker test-subprocess-workflows
 
 test-docker: pyclean
-	DEBUG_SPAWNER=docker pytest tests -rs -v --durations=0 -m "not workflows" --local
+	pytest tests -rs -v --durations=0 -m "not workflows" --mode=docker
 
 test-subprocess: pyclean
-	DEBUG_SPAWNER=subprocess pytest tests -rs -v --durations=0 -m "not workflows and not subprocess_skip" --local
+	pytest tests -rs -v --durations=0 -m "not workflows and not subprocess_skip" --mode=subprocess
 
 test-subprocess-workflows: pyclean
-	DEBUG_SPAWNER=subprocess pytest tests -v --durations=0 -m "workflows" --local
+	pytest tests -v --durations=0 -m "workflows" --mode=subprocess
 
 test-all: test-local test-remote
diff --git a/substratest/client.py b/substratest/client.py
index 754c5761..274f0d6a 100644
--- a/substratest/client.py
+++ b/substratest/client.py
@@ -38,7 +38,7 @@ class Client:
 
     def __init__(
         self,
-        debug: bool,
+        backend_type: substra.BackendType,
         organization_id: str,
         address: str,
         user: str,
@@ -51,7 +51,7 @@ def __init__(
         super().__init__()
 
         self.organization_id = organization_id
-        self._client = substra.Client(debug=debug, url=address, insecure=False, token=token)
+        self._client = substra.Client(backend_type=backend_type, url=address, insecure=False, token=token)
         if not token:
            token = self._client.login(user, password)
        self._api_client = _APIClient(address, token)
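For context, a minimal migration sketch for call sites of `substra.Client`, mirroring what `substratest/client.py` now does internally. The URL and token values are placeholders; the enum members are the ones used elsewhere in this diff.

```python
import substra

# Before: substra.Client(debug=True, ...) plus the DEBUG_SPAWNER environment variable.
# After: the execution backend is selected explicitly via backend_type.
remote_client = substra.Client(
    backend_type=substra.BackendType.REMOTE,
    url="https://org-1.example.com",  # placeholder address
    insecure=False,
    token="<api-token>",  # placeholder; fetched via login() when missing
)

# Local execution in Docker containers (previously debug=True + DEBUG_SPAWNER=docker):
local_client = substra.Client(
    backend_type=substra.BackendType.LOCAL_DOCKER,
    url="https://org-1.example.com",  # placeholder address
    insecure=False,
    token="<api-token>",
)
```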
""" - local = config.getoption("--local") - if local: + mode = substra.BackendType(config.getoption("--mode")) + if mode != substra.BackendType.REMOTE: skip_marker = pytest.mark.skip(reason="remove the --local option to run") keyword = "remote_only" else: @@ -90,11 +92,9 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope="session") -def client_debug_local(request): - local = request.config.getoption("--local") - if local: - return True - return False +def client_mode(request): + mode = request.config.getoption("--mode") + return substra.BackendType(mode) class _DataEnv: @@ -129,23 +129,26 @@ class Network: @pytest.fixture -def factory(request, cfg, client_debug_local): +def factory(request, cfg, client_mode): """Factory fixture. Provide class methods to simply create asset specification in order to add them to the substra framework. """ name = f"{TESTS_RUN_UUID}_{request.node.name}" - with sbt.AssetsFactory(name=name, cfg=cfg, client_debug_local=client_debug_local) as f: + with sbt.AssetsFactory( + name=name, + cfg=cfg, + client_debug_local=(client_mode in [substra.BackendType.LOCAL_SUBPROCESS, substra.BackendType.LOCAL_DOCKER]), + ) as f: yield f @pytest.fixture(scope="session") -def cfg(client_debug_local): - if not client_debug_local: +def cfg(client_mode): + if client_mode == substra.BackendType.REMOTE: return settings.Settings.load() else: - # TODO check what enable_intermediate_model_removal does return settings.Settings.load_local_backend() @@ -162,7 +165,7 @@ def debug_factory(request, cfg): @pytest.fixture(scope="session") -def network(cfg, client_debug_local): +def network(cfg, client_mode): """Network fixture. Network must be started outside of the tests environment and the network is kept @@ -174,7 +177,7 @@ def network(cfg, client_debug_local): """ clients = [ sbt.Client( - debug=client_debug_local, + backend_type=client_mode, organization_id=n.msp_id, address=n.address, user=n.user, @@ -191,7 +194,7 @@ def network(cfg, client_debug_local): @pytest.fixture(scope="session") -def default_data_env(cfg, network, client_debug_local): +def default_data_env(cfg, network, client_mode): """Fixture with pre-existing assets in all organizations. The following assets will be created for each organization: @@ -207,7 +210,11 @@ def default_data_env(cfg, network, client_debug_local): """ factory_name = f"{TESTS_RUN_UUID}_global" - with sbt.AssetsFactory(name=factory_name, cfg=cfg, client_debug_local=client_debug_local) as f: + with sbt.AssetsFactory( + name=factory_name, + cfg=cfg, + client_debug_local=(client_mode in [substra.BackendType.LOCAL_SUBPROCESS, substra.BackendType.LOCAL_DOCKER]), + ) as f: datasets = [] metrics = [] for index, client in enumerate(network.clients): @@ -338,17 +345,17 @@ def channel(cfg, network): @pytest.fixture(scope="session") -def debug_client(cfg, client): +def hybrid_client(cfg, client): """ - Client fixture in debug mode (first organization). + Client fixture in hybrid mode (first organization). Use it with @pytest.mark.remote_only """ organization = cfg.organizations[0] - # Debug client and client share the same + # Hybrid client and client share the same # token, otherwise when one connects the other # is disconnected. 
diff --git a/tests/test_execution.py b/tests/test_execution.py
index 16583575..84767395 100644
--- a/tests/test_execution.py
+++ b/tests/test_execution.py
@@ -211,7 +211,7 @@ def test_traintuple_execution_failure(factory, network, default_dataset_1):
 
     spec = factory.create_traintuple(algo=algo, inputs=default_dataset_1.train_data_inputs)
 
-    if network.clients[0].backend_mode != substra.BackendType.DEPLOYED:
+    if network.clients[0].backend_mode != substra.BackendType.REMOTE:
         with pytest.raises(substra.sdk.backends.local.compute.spawner.base.ExecutionError):
             network.clients[0].add_traintuple(spec)
     else:
@@ -236,7 +236,7 @@ def test_composite_traintuple_execution_failure(factory, client, default_dataset
     algo = client.add_algo(spec)
 
     spec = factory.create_composite_traintuple(algo=algo, inputs=default_dataset.train_data_inputs)
-    if client.backend_mode == substra.BackendType.DEPLOYED:
+    if client.backend_mode == substra.BackendType.REMOTE:
         composite_traintuple = client.add_composite_traintuple(spec)
         composite_traintuple = client.wait(composite_traintuple, raises=False)
 
@@ -277,7 +277,7 @@ def test_aggregatetuple_execution_failure(factory, client, default_dataset):
         worker=client.organization_id,
     )
 
-    if client.backend_mode == substra.BackendType.DEPLOYED:
+    if client.backend_mode == substra.BackendType.REMOTE:
         aggregatetuple = client.add_aggregatetuple(spec)
         aggregatetuple = client.wait(aggregatetuple, raises=False)
 
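The three hunks above repeat one pattern: local backends run tasks inline and raise `ExecutionError` at submission time, while the remote backend accepts the task and reports the failure asynchronously through its status. A hypothetical helper (illustrative only, not part of this diff) capturing that pattern:

```python
import pytest
import substra


def submit_expecting_failure(client, add_tuple, spec):
    """Submit a tuple that is expected to fail, whatever the backend mode."""
    if client.backend_mode == substra.BackendType.REMOTE:
        # Remote: submission succeeds, the failure shows up in the task status.
        task = add_tuple(spec)
        return client.wait(task, raises=False)
    # Local: execution happens synchronously, so submission itself raises.
    with pytest.raises(substra.sdk.backends.local.compute.spawner.base.ExecutionError):
        add_tuple(spec)
```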
diff --git a/tests/test_hybrid_debug.py b/tests/test_hybrid_mode.py
similarity index 88%
rename from tests/test_hybrid_debug.py
rename to tests/test_hybrid_mode.py
index c511ae94..ca51d6a4 100644
--- a/tests/test_hybrid_debug.py
+++ b/tests/test_hybrid_mode.py
@@ -21,7 +21,7 @@ def docker_available() -> bool:
 
 @pytest.mark.remote_only
 @pytest.mark.slow
-def test_execution_debug(client, debug_client, debug_factory, default_dataset):
+def test_execution_debug(client, hybrid_client, debug_factory, default_dataset):
     spec = debug_factory.create_algo(AlgoCategory.simple)
     simple_algo = client.add_algo(spec)
 
@@ -36,7 +36,7 @@ def test_execution_debug(client, hybrid_client, debug_factory, default_dataset):
     spec = debug_factory.create_traintuple(
         algo=simple_algo, inputs=default_dataset.opener_input + default_dataset.train_data_sample_inputs[:1]
     )
-    traintuple = debug_client.add_traintuple(spec)
+    traintuple = hybrid_client.add_traintuple(spec)
     assert traintuple.status == models.Status.done
     assert len(traintuple.train.models) != 0
 
@@ -47,7 +47,7 @@
         + default_dataset.train_data_sample_inputs[:1]
         + FLTaskInputGenerator.train_to_predict(traintuple.key),
     )
-    predicttuple = debug_client.add_predicttuple(spec)
+    predicttuple = hybrid_client.add_predicttuple(spec)
     assert predicttuple.status == models.Status.done
 
     spec = debug_factory.create_testtuple(
@@ -56,19 +56,19 @@
         + default_dataset.train_data_sample_inputs[:1]
         + FLTaskInputGenerator.predict_to_test(predicttuple.key),
     )
-    testtuple = debug_client.add_testtuple(spec)
+    testtuple = hybrid_client.add_testtuple(spec)
     assert testtuple.status == models.Status.done
     assert list(testtuple.test.perfs.values())[0] == 3
 
 
 @pytest.mark.remote_only
 @pytest.mark.slow
-def test_debug_compute_plan_aggregate_composite(network, client, debug_client, debug_factory, default_datasets):
+def test_debug_compute_plan_aggregate_composite(network, client, hybrid_client, debug_factory, default_datasets):
     """
     Debug / Compute plan version of the `test_aggregate_composite_traintuples` method
     from `test_execution.py`
     """
-    aggregate_worker = debug_client.organization_id
+    aggregate_worker = hybrid_client.organization_id
     number_of_rounds = 2
 
     # register algos on first organization
@@ -133,8 +133,8 @@ def test_debug_compute_plan_aggregate_composite(network, client, debug_client, d
             algo=metric, inputs=dataset.train_data_inputs + FLTaskInputGenerator.predict_to_test(spec.predicttuple_id)
         )
 
-    cp = debug_client.add_compute_plan(cp_spec)
-    traintuples = debug_client.list_compute_plan_traintuples(cp.key)
+    cp = hybrid_client.add_compute_plan(cp_spec)
+    traintuples = hybrid_client.list_compute_plan_traintuples(cp.key)
     composite_traintuples = client.list_compute_plan_composite_traintuples(cp.key)
     aggregatetuples = client.list_compute_plan_aggregatetuples(cp.key)
     predicttuples = client.list_compute_plan_predicttuples(cp.key)
@@ -146,13 +146,13 @@ def test_debug_compute_plan_aggregate_composite(network, client, debug_client, d
 
 
 @pytest.mark.remote_only
-def test_debug_download_dataset(debug_client, default_dataset):
-    debug_client.download_opener(default_dataset.key)
+def test_debug_download_dataset(hybrid_client, default_dataset):
+    hybrid_client.download_opener(default_dataset.key)
 
 
 @pytest.mark.remote_only
 @pytest.mark.slow
-def test_test_data_traintuple(client, debug_client, debug_factory, default_dataset):
+def test_test_data_traintuple(client, hybrid_client, debug_factory, default_dataset):
     """Check that we can't use test data samples for traintuples"""
     spec = debug_factory.create_algo(AlgoCategory.simple)
     algo = client.add_algo(spec)
@@ -165,13 +165,13 @@ def test_test_data_traintuple(client, debug_client, debug_factory, default_datas
     )
 
     with pytest.raises(InvalidRequest) as e:
-        debug_client.add_traintuple(spec)
+        hybrid_client.add_traintuple(spec)
     assert "Cannot create train task with test data" in str(e.value)
 
 
 @pytest.mark.remote_only
 @pytest.mark.slow
-def test_fake_data_sample_key(client, debug_client, debug_factory, default_dataset):
+def test_fake_data_sample_key(client, hybrid_client, debug_factory, default_dataset):
     """Check that a traintuple can't run with a fake train_data_sample_keys"""
     spec = debug_factory.create_algo(AlgoCategory.simple)
     algo = client.add_algo(spec)
@@ -183,5 +183,5 @@ def test_fake_data_sample_key(client, debug_client, debug_factory, default_datas
     )
 
     with pytest.raises(InvalidRequest) as e:
-        debug_client.add_traintuple(spec)
+        hybrid_client.add_traintuple(spec)
     assert "Could not get all the data_samples in the database with the given data_sample_keys" in str(e.value)
diff --git a/tests/workflows/README.md b/tests/workflows/README.md
index 329a77df..3f1d4190 100644
--- a/tests/workflows/README.md
+++ b/tests/workflows/README.md
@@ -9,4 +9,4 @@ mnist_workflow:
 
 ## run locally with docker or subprocess
 
-`DEBUG_SPAWNER=subprocess pytest tests -v --durations=0 -m "workflows" --local`
+`pytest tests -v --durations=0 -m "workflows" --mode=subprocess`
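For completeness, the updated invocation can also be driven from Python through pytest's public entry point; a rough equivalent of the README command, assuming it is run from the repository root:

```python
import pytest

# Same as: pytest tests -v --durations=0 -m "workflows" --mode=subprocess
exit_code = pytest.main(["tests", "-v", "--durations=0", "-m", "workflows", "--mode=subprocess"])
```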
self.load_model(inputs["model"]) + model = self.load_model(inputs["models"]) model.eval() # add the context manager to reduce computation overhead with torch.no_grad():