diff --git a/notebooks/api/0.8/11-container-images-k8s.ipynb b/notebooks/api/0.8/11-container-images-k8s.ipynb index ef21f427a26..7cca83448f4 100644 --- a/notebooks/api/0.8/11-container-images-k8s.ipynb +++ b/notebooks/api/0.8/11-container-images-k8s.ipynb @@ -493,8 +493,7 @@ "assert workerimage is not None, str([image.__dict__ for image in image_list])\n", "assert workerimage.is_built is not None, str(workerimage)\n", "assert workerimage.built_at is not None, str(workerimage)\n", - "assert workerimage.image_hash is not None, str(workerimage)\n", - "assert image_list[workerimage.built_image_tag] == workerimage" + "assert workerimage.image_hash is not None, str(workerimage)" ] }, { @@ -1037,10 +1036,6 @@ "assert workerimage_opendp.built_at is not None, str(workerimage_opendp.__dict__)\n", "assert workerimage_opendp.image_hash is not None, str(workerimage_opendp.__dict__)\n", "\n", - "assert _images[workerimage_opendp.built_image_tag] == workerimage_opendp, str(\n", - " workerimage_opendp\n", - ")\n", - "\n", "workerimage_opendp" ] }, @@ -1394,6 +1389,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "syft-3.11", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -1404,7 +1404,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb b/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb index 86cbdc836c6..3aabe924a4c 100644 --- a/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb +++ b/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb @@ -25,6 +25,7 @@ "\n", "# syft absolute\n", "import syft as sy\n", + "from syft.util.test_helpers.checkpoint import create_checkpoint\n", "from syft.util.test_helpers.email_helpers import get_email_server" ] }, @@ -87,6 +88,15 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "root_client.users" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -201,7 +211,7 @@ "metadata": {}, "outputs": [], "source": [ - "smtp_server.stop()" + "create_checkpoint(name=\"000-start-and-config\", client=root_client)" ] }, { @@ -210,7 +220,7 @@ "metadata": {}, "outputs": [], "source": [ - "server.land()" + "smtp_server.stop()" ] }, { @@ -218,12 +228,14 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "server.land()" + ] } ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -237,7 +249,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb index a8299b5cdcd..65ad4ae6dde 100644 --- a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb +++ b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb @@ -22,9 +22,11 @@ "source": [ "# stdlib\n", "import os\n", + "from os import environ as env\n", "\n", "# syft absolute\n", "import syft as sy\n", + "from syft.util.test_helpers.checkpoint import load_from_checkpoint\n", "from syft.util.test_helpers.email_helpers import Timeout\n", "from 
syft.util.test_helpers.email_helpers import get_email_server" ] @@ -46,9 +48,21 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "3", "metadata": {}, + "outputs": [], + "source": [ + "# in case we are not in k8s we set them here for orchestra to use\n", + "env[\"DEFAULT_ROOT_EMAIL\"] = ROOT_EMAIL\n", + "env[\"DEFAULT_ROOT_PASSWORD\"] = ROOT_PASSWORD" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, "source": [ "### Launch server & login" ] @@ -56,7 +70,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -67,13 +81,29 @@ " port=\"8080\",\n", " n_consumers=num_workers, # How many workers to be spawned\n", " create_producer=True, # Can produce more workers\n", + " log_level=10,\n", ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "load_from_checkpoint(\n", + " name=\"000-start-and-config\",\n", + " client=server.client,\n", + " root_email=ROOT_EMAIL,\n", + " root_password=ROOT_PASSWORD,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -83,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -95,7 +125,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -105,7 +135,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -115,7 +145,7 @@ }, { "cell_type": "markdown", - "id": "9", + "id": "11", "metadata": {}, "source": [ "### Scale Worker pool" @@ -123,7 +153,7 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "12", "metadata": {}, "source": [ "##### Scale up" @@ -132,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -146,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -156,7 +186,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -176,7 +206,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "16", "metadata": {}, "source": [ "##### Scale down" @@ -185,7 +215,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -200,7 +230,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -219,7 +249,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -232,7 +262,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "20", "metadata": {}, "source": [ "#### Delete Worker Pool" @@ -241,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "21", "metadata": {}, "outputs": [], "source": [ @@ -254,7 +284,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -264,7 +294,7 @@ }, { "cell_type": "markdown", - "id": "21", + "id": "23", "metadata": {}, "source": [ "#### Re-launch the default worker pool" @@ -273,7 +303,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "24", "metadata": {}, "outputs": [], 
"source": [ @@ -283,7 +313,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -297,7 +327,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "27", "metadata": {}, "outputs": [], "source": [ @@ -321,25 +351,17 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "28", "metadata": {}, "outputs": [], "source": [ "server.land()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { "kernelspec": { - "display_name": "syft", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -353,7 +375,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.11.10" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb b/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb index 94aacf20397..b478a79f27f 100644 --- a/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb +++ b/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb @@ -388,7 +388,7 @@ " (\n", " image\n", " for image in dockerfile_list\n", - " if \"worker-bigquery\" in str(image.image_identifier)\n", + " if image.is_prebuilt and \"worker-bigquery\" in str(image.image_identifier)\n", " ),\n", " None,\n", ")\n", diff --git a/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/schema.py b/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/schema.py index d1ff1cf9b05..9ef82352e02 100644 --- a/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/schema.py +++ b/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/schema.py @@ -4,7 +4,9 @@ # syft absolute import syft as sy from syft import test_settings -from syft.rate_limiter import is_within_rate_limit + +# relative +from ..rate_limiter import is_within_rate_limit def make_schema(settings: dict, worker_pool_name: str) -> Callable: diff --git a/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/test_query.py b/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/test_query.py index ccd3c75b599..344879dcb62 100644 --- a/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/test_query.py +++ b/notebooks/scenarios/bigquery/upgradability/0.9.1_helpers/apis/live/test_query.py @@ -4,7 +4,9 @@ # syft absolute import syft as sy from syft import test_settings -from syft.rate_limiter import is_within_rate_limit + +# relative +from ..rate_limiter import is_within_rate_limit def make_test_query(settings) -> Callable: diff --git a/packages/grid/backend/grid/bootstrap.py b/packages/grid/backend/grid/bootstrap.py index 914411ec864..3da7eb9b600 100644 --- a/packages/grid/backend/grid/bootstrap.py +++ b/packages/grid/backend/grid/bootstrap.py @@ -121,9 +121,9 @@ def get_credential( # supplying a different key means something has gone wrong so raise Exception if ( - file_credential != env_credential - and file_credential is not None + file_credential is not None and env_credential is not None + and validation_func(file_credential) != validation_func(env_credential) ): raise Exception(f"{key} from ENV must match {key} in {CREDENTIALS_PATH}") diff --git a/packages/grid/backend/grid/core/server.py 
b/packages/grid/backend/grid/core/server.py index c6e4568afe7..fa674d94dcb 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -6,13 +6,13 @@ from syft.server.datasite import Datasite from syft.server.datasite import Server from syft.server.enclave import Enclave +from syft.server.env import get_default_bucket_name +from syft.server.env import get_enable_warnings +from syft.server.env import get_server_name +from syft.server.env import get_server_side_type +from syft.server.env import get_server_type +from syft.server.env import get_server_uid_env from syft.server.gateway import Gateway -from syft.server.server import get_default_bucket_name -from syft.server.server import get_enable_warnings -from syft.server.server import get_server_name -from syft.server.server import get_server_side_type -from syft.server.server import get_server_type -from syft.server.server import get_server_uid_env from syft.service.queue.zmq_client import ZMQClientConfig from syft.service.queue.zmq_client import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index bf796f8bc42..390fcd2b00c 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -42,6 +42,7 @@ images: target: "backend" context: ../ tags: + - dev-latest - dev-${DEVSPACE_TIMESTAMP} frontend: image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_FRONTEND}" @@ -52,6 +53,7 @@ images: context: ./frontend tags: - dev-${DEVSPACE_TIMESTAMP} + - dev-latest seaweedfs: image: "${CONTAINER_REGISTRY}/${DOCKER_IMAGE_SEAWEEDFS}" buildKit: @@ -60,6 +62,7 @@ images: context: ./seaweedfs tags: - dev-${DEVSPACE_TIMESTAMP} + - dev-latest # This is a list of `deployments` that DevSpace can create for this project deployments: @@ -73,6 +76,7 @@ deployments: global: registry: ${CONTAINER_REGISTRY} version: dev-${DEVSPACE_TIMESTAMP} + workerVersion: dev-latest # anything that does not need templating should go in helm/examples/dev/base.yaml # or profile specific values files valuesFiles: @@ -165,6 +169,7 @@ profiles: context: ./rathole tags: - dev-${DEVSPACE_TIMESTAMP} + - dev-latest # use rathole client-specific chart values - op: add path: deployments.syft.helm.valuesFiles @@ -185,6 +190,7 @@ profiles: context: ./rathole tags: - dev-${DEVSPACE_TIMESTAMP} + - dev-latest # enable rathole `devspace dev` config - op: add path: dev @@ -256,6 +262,7 @@ profiles: dockerfile: ./enclave/attestation/attestation.dockerfile context: ./enclave/attestation tags: + - dev-latest - dev-${DEVSPACE_TIMESTAMP} - op: add path: dev.backend.containers diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index bb126d3dcd3..67b0db4d452 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -70,7 +70,7 @@ spec: - name: INMEMORY_WORKERS value: {{ .Values.server.inMemoryWorkers | quote }} - name: DEFAULT_WORKER_POOL_IMAGE - value: "{{ .Values.global.registry }}/openmined/syft-backend:{{ .Values.global.version }}" + value: "{{ .Values.global.registry }}/openmined/syft-backend:{{ .Values.global.workerVersion }}" - name: DEFAULT_WORKER_POOL_COUNT value: {{ .Values.server.defaultWorkerPool.count | quote }} - name: DEFAULT_WORKER_POOL_POD_LABELS diff --git a/packages/syft/src/syft/abstract_server.py 
b/packages/syft/src/syft/abstract_server.py index 3b7885f0a0e..8c945f9abbb 100644 --- a/packages/syft/src/syft/abstract_server.py +++ b/packages/syft/src/syft/abstract_server.py @@ -6,6 +6,7 @@ # relative from .serde.serializable import serializable from .store.db.db import DBConfig +from .store.db.db import DBManager from .types.uid import UID if TYPE_CHECKING: @@ -43,6 +44,7 @@ class AbstractServer: in_memory_workers: bool services: "ServiceRegistry" db_config: DBConfig + db: DBManager[DBConfig] def get_service(self, path_or_func: str | Callable) -> "AbstractService": raise NotImplementedError diff --git a/packages/syft/src/syft/client/datasite_client.py b/packages/syft/src/syft/client/datasite_client.py index 0129b4a17ac..b88a7081299 100644 --- a/packages/syft/src/syft/client/datasite_client.py +++ b/packages/syft/src/syft/client/datasite_client.py @@ -417,7 +417,10 @@ def get_migration_data(self, include_blobs: bool = True) -> MigrationData: return res def load_migration_data( - self, path_or_data: str | Path | MigrationData + self, + path_or_data: str | Path | MigrationData, + include_worker_pools: bool = False, + with_reset_db: bool = False, ) -> SyftSuccess: if isinstance(path_or_data, MigrationData): migration_data = path_or_data @@ -437,7 +440,7 @@ def load_migration_data( public_message="Root verify key in migration data does not match this client's verify key" ) - if migration_data.includes_custom_workerpools: + if migration_data.includes_custom_workerpools and not include_worker_pools: prompt_warning_message( "This migration data includes custom workers, " "which need to be migrated separately with `sy.upgrade_custom_workerpools` " @@ -445,9 +448,15 @@ def load_migration_data( ) migration_data.migrate_and_upload_blobs() + migration_data = migration_data.copy_without_blobs() + + if not include_worker_pools: + migration_data = migration_data.copy_without_workerpools() - migration_data = migration_data.copy_without_workerpools().copy_without_blobs() - return self.api.services.migration.apply_migration_data(migration_data) + if with_reset_db: + return self.api.services.migration.reset_and_restore(migration_data) + else: + return self.api.services.migration.apply_migration_data(migration_data) def dump_state(self, path: str | Path) -> None: if isinstance(path, str): diff --git a/packages/syft/src/syft/server/env.py b/packages/syft/src/syft/server/env.py new file mode 100644 index 00000000000..c101f05bad1 --- /dev/null +++ b/packages/syft/src/syft/server/env.py @@ -0,0 +1,120 @@ +# stdlib +import json +import subprocess # nosec +import sys + +# relative +from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME +from ..types.uid import UID +from ..util.util import get_env +from ..util.util import str_to_bool + +SERVER_PRIVATE_KEY = "SERVER_PRIVATE_KEY" +SERVER_UID = "SERVER_UID" +SERVER_TYPE = "SERVER_TYPE" +SERVER_NAME = "SERVER_NAME" +SERVER_SIDE_TYPE = "SERVER_SIDE_TYPE" + +DEFAULT_ROOT_EMAIL = "DEFAULT_ROOT_EMAIL" +DEFAULT_ROOT_USERNAME = "DEFAULT_ROOT_USERNAME" +DEFAULT_ROOT_PASSWORD = "DEFAULT_ROOT_PASSWORD" # nosec + + +def get_private_key_env() -> str | None: + return get_env(SERVER_PRIVATE_KEY) + + +def get_server_type() -> str | None: + return get_env(SERVER_TYPE, "datasite") + + +def get_server_name() -> str | None: + return get_env(SERVER_NAME, None) + + +def get_server_side_type() -> str | None: + return get_env(SERVER_SIDE_TYPE, "high") + + +def get_server_uid_env() -> str | None: + return get_env(SERVER_UID) + + +def get_default_root_email() -> str | None: + return 
get_env(DEFAULT_ROOT_EMAIL, "info@openmined.org")
+
+
+def get_default_root_username() -> str | None:
+    return get_env(DEFAULT_ROOT_USERNAME, "Jane Doe")
+
+
+def get_default_root_password() -> str | None:
+    return get_env(DEFAULT_ROOT_PASSWORD, "changethis")  # nosec
+
+
+def get_enable_warnings() -> bool:
+    return str_to_bool(get_env("ENABLE_WARNINGS", "False"))
+
+
+def get_container_host() -> str | None:
+    return get_env("CONTAINER_HOST")
+
+
+def get_default_worker_image() -> str | None:
+    return get_env("DEFAULT_WORKER_POOL_IMAGE")
+
+
+def get_default_worker_pool_name() -> str | None:
+    return get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME)
+
+
+def get_default_bucket_name() -> str:
+    env = get_env("DEFAULT_BUCKET_NAME")
+    server_id = get_server_uid_env() or "syft-bucket"
+    return env or server_id or "syft-bucket"
+
+
+def get_default_worker_pool_pod_annotations() -> dict[str, str] | None:
+    annotations = get_env("DEFAULT_WORKER_POOL_POD_ANNOTATIONS", "null")
+    return json.loads(annotations)
+
+
+def get_default_worker_pool_pod_labels() -> dict[str, str] | None:
+    labels = get_env("DEFAULT_WORKER_POOL_POD_LABELS", "null")
+    return json.loads(labels)
+
+
+def in_kubernetes() -> bool:
+    return get_container_host() == "k8s"
+
+
+def get_venv_packages() -> str:
+    try:
+        # subprocess call is safe because it uses a fully qualified path and fixed arguments
+        result = subprocess.run(
+            [sys.executable, "-m", "pip", "list", "--format=freeze"],  # nosec
+            capture_output=True,
+            check=True,
+            text=True,
+        )
+        return result.stdout
+    except subprocess.CalledProcessError as e:
+        return f"An error occurred: {e.stderr}"
+
+
+def get_syft_worker() -> bool:
+    return str_to_bool(get_env("SYFT_WORKER", "false"))
+
+
+def get_k8s_pod_name() -> str | None:
+    return get_env("K8S_POD_NAME")
+
+
+def get_syft_worker_uid() -> str | None:
+    is_worker = get_syft_worker()
+    pod_name = get_k8s_pod_name()
+    uid = get_env("SYFT_WORKER_UID")
+    # if uid is empty and this is a K8S worker, generate a uid from the pod name
+    if (not uid) and is_worker and pod_name:
+        uid = str(UID.with_seed(pod_name))
+    return uid
diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py
index 7ea7bed2b78..b94e98aac35 100644
--- a/packages/syft/src/syft/server/server.py
+++ b/packages/syft/src/syft/server/server.py
@@ -10,12 +10,9 @@
 from datetime import timezone
 from functools import partial
 import hashlib
-import json
 import logging
 import os
 from pathlib import Path
-import subprocess  # nosec
-import sys
 import threading
 from time import sleep
 import traceback
@@ -75,10 +72,9 @@
 from ..service.service import UserServiceConfigRegistry
 from ..service.settings.settings import ServerSettings
 from ..service.settings.settings import ServerSettingsUpdate
-from ..service.user.user import User
-from ..service.user.user import UserCreate
 from ..service.user.user import UserView
 from ..service.user.user_roles import ServiceRole
+from ..service.user.utils import create_root_admin_if_not_exists
 from ..service.worker.utils import DEFAULT_WORKER_IMAGE_TAG
 from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME
 from ..service.worker.utils import create_default_image
@@ -116,10 +112,20 @@
 from ..util.util import get_env
 from ..util.util import get_queue_address
 from ..util.util import random_name
-from ..util.util import str_to_bool
 from ..util.util import thread_ident
 from .credentials import SyftSigningKey
 from .credentials import SyftVerifyKey
+from .env import get_default_root_email
+from .env import
get_default_root_password +from .env import get_default_root_username +from .env import get_default_worker_image +from .env import get_default_worker_pool_name +from .env import get_default_worker_pool_pod_annotations +from .env import get_default_worker_pool_pod_labels +from .env import get_private_key_env +from .env import get_server_uid_env +from .env import get_syft_worker_uid +from .env import in_kubernetes from .service_registry import ServiceRegistry from .utils import get_named_server_uid from .utils import get_temp_dir_for_server @@ -135,71 +141,6 @@ CODE_RELOADER: dict[int, Callable] = {} -SERVER_PRIVATE_KEY = "SERVER_PRIVATE_KEY" -SERVER_UID = "SERVER_UID" -SERVER_TYPE = "SERVER_TYPE" -SERVER_NAME = "SERVER_NAME" -SERVER_SIDE_TYPE = "SERVER_SIDE_TYPE" - -DEFAULT_ROOT_EMAIL = "DEFAULT_ROOT_EMAIL" -DEFAULT_ROOT_USERNAME = "DEFAULT_ROOT_USERNAME" -DEFAULT_ROOT_PASSWORD = "DEFAULT_ROOT_PASSWORD" # nosec - - -def get_private_key_env() -> str | None: - return get_env(SERVER_PRIVATE_KEY) - - -def get_server_type() -> str | None: - return get_env(SERVER_TYPE, "datasite") - - -def get_server_name() -> str | None: - return get_env(SERVER_NAME, None) - - -def get_server_side_type() -> str | None: - return get_env(SERVER_SIDE_TYPE, "high") - - -def get_server_uid_env() -> str | None: - return get_env(SERVER_UID) - - -def get_default_root_email() -> str | None: - return get_env(DEFAULT_ROOT_EMAIL, "info@openmined.org") - - -def get_default_root_username() -> str | None: - return get_env(DEFAULT_ROOT_USERNAME, "Jane Doe") - - -def get_default_root_password() -> str | None: - return get_env(DEFAULT_ROOT_PASSWORD, "changethis") # nosec - - -def get_enable_warnings() -> bool: - return str_to_bool(get_env("ENABLE_WARNINGS", "False")) - - -def get_container_host() -> str | None: - return get_env("CONTAINER_HOST") - - -def get_default_worker_image() -> str | None: - return get_env("DEFAULT_WORKER_POOL_IMAGE") - - -def get_default_worker_pool_name() -> str | None: - return get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME) - - -def get_default_bucket_name() -> str: - env = get_env("DEFAULT_BUCKET_NAME") - server_id = get_server_uid_env() or "syft-bucket" - return env or server_id or "syft-bucket" - - def get_default_worker_pool_count(server: Server) -> int: return int( get_env( @@ -208,52 +149,6 @@ def get_default_worker_pool_count(server: Server) -> int: ) -def get_default_worker_pool_pod_annotations() -> dict[str, str] | None: - annotations = get_env("DEFAULT_WORKER_POOL_POD_ANNOTATIONS", "null") - return json.loads(annotations) - - -def get_default_worker_pool_pod_labels() -> dict[str, str] | None: - labels = get_env("DEFAULT_WORKER_POOL_POD_LABELS", "null") - return json.loads(labels) - - -def in_kubernetes() -> bool: - return get_container_host() == "k8s" - - -def get_venv_packages() -> str: - try: - # subprocess call is safe because it uses a fully qualified path and fixed arguments - result = subprocess.run( - [sys.executable, "-m", "pip", "list", "--format=freeze"], # nosec - capture_output=True, - check=True, - text=True, - ) - return result.stdout - except subprocess.CalledProcessError as e: - return f"An error occurred: {e.stderr}" - - -def get_syft_worker() -> bool: - return str_to_bool(get_env("SYFT_WORKER", "false")) - - -def get_k8s_pod_name() -> str | None: - return get_env("K8S_POD_NAME") - - -def get_syft_worker_uid() -> str | None: - is_worker = get_syft_worker() - pod_name = get_k8s_pod_name() - uid = get_env("SYFT_WORKER_UID") - # if uid is empty is a K8S worker, 
generate a uid from the pod name - if (not uid) and is_worker and pod_name: - uid = str(UID.with_seed(pod_name)) - return uid - - signing_key_env = get_private_key_env() server_uid_env = get_server_uid_env() @@ -1735,59 +1630,6 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings: ).unwrap() -def create_root_admin_if_not_exists( - name: str, - email: str, - password: str, - server: Server, -) -> User | None: - """ - If no root admin exists: - - all exists checks on the user stash will fail, as we cannot get the role for the admin to check if it exists - - result: a new admin is always created - - If a root admin exists with a different email: - - cause: DEFAULT_USER_EMAIL env variable is set to a different email than the root admin in the db - - verify_key_exists will return True - - result: no new admin is created, as the server already has a root admin - """ - user_stash = server.services.user.stash - - email_exists = user_stash.email_exists(email=email).unwrap() - if email_exists: - logger.debug("Admin not created, a user with this email already exists") - return None - - verify_key_exists = user_stash.verify_key_exists(server.verify_key).unwrap() - if verify_key_exists: - logger.debug("Admin not created, this server already has a root admin") - return None - - create_user = UserCreate( - name=name, - email=email, - password=password, - password_verify=password, - role=ServiceRole.ADMIN, - ) - - # New User Initialization - # 🟡 TODO: change later but for now this gives the main user super user automatically - user = create_user.to(User) - user.signing_key = server.signing_key - user.verify_key = server.verify_key - - new_user = user_stash.set( - credentials=server.verify_key, - obj=user, - ignore_duplicates=False, - ).unwrap() - - logger.debug(f"Created admin {new_user.email}") - - return new_user - - class ServerRegistry: __server_registry__: dict[UID, Server] = {} diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index eef1a113af7..62788762acf 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,5 +1,6 @@ # stdlib from collections import defaultdict +import logging # syft absolute import syft @@ -21,6 +22,7 @@ from ..action.action_permissions import StoragePermission from ..action.action_store import ActionObjectStash from ..context import AuthedServiceContext +from ..response import SyftError from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method @@ -31,6 +33,8 @@ from .object_migration_state import SyftMigrationStateStash from .object_migration_state import SyftObjectMigrationState +logger = logging.getLogger(__name__) + @serializable(canonical_name="MigrationService", version=1) class MigrationService(AbstractService): @@ -260,6 +264,7 @@ def _create_migrated_objects( skip_check_type: bool = False, ) -> dict[type[SyftObject], list[SyftObject]]: created_objects: dict[type[SyftObject], list[SyftObject]] = {} + for key, objects in migrated_objects.items(): created_objects[key] = [] for migrated_object in objects: @@ -317,7 +322,7 @@ def _migrate_objects( latest_version = SyftObjectRegistry.get_latest_version(canonical_name) # Migrate data for objects in document store - print( + logger.info( f"Migrating data for: {canonical_name} table to version {latest_version}" ) for object in objects: @@ -463,3 +468,28 @@ def 
apply_migration_data( # apply metadata self._update_store_metadata(context, migration_data.metadata).unwrap() return SyftSuccess(message="Migration completed successfully") + + @service_method( + path="migration.reset_and_restore", + name="reset_and_restore", + roles=ADMIN_ROLE_LEVEL, + unwrap_on_success=False, + ) + def reset_and_restore( + self, + context: AuthedServiceContext, + migration_data: MigrationData, + ) -> SyftSuccess | SyftError: + try: + root_verify_key = context.server.verify_key + context.server.db.init_tables(reset=True) + context.credentials = root_verify_key + self.apply_migration_data(context, migration_data) + except Exception as e: + return SyftError.from_exception( + context=context, + exc=e, + include_traceback=True, + ) + + return SyftSuccess(message="Database reset successfully.") diff --git a/packages/syft/src/syft/service/user/utils.py b/packages/syft/src/syft/service/user/utils.py new file mode 100644 index 00000000000..191fc4fe181 --- /dev/null +++ b/packages/syft/src/syft/service/user/utils.py @@ -0,0 +1,63 @@ +# stdlib +import logging + +# relative +from ...abstract_server import AbstractServer +from .user import User +from .user import UserCreate +from .user_roles import ServiceRole + +logger = logging.getLogger(__name__) + + +def create_root_admin_if_not_exists( + name: str, + email: str, + password: str, + server: AbstractServer, +) -> User | None: + """ + If no root admin exists: + - all exists checks on the user stash will fail, as we cannot get the role for the admin to check if it exists + - result: a new admin is always created + + If a root admin exists with a different email: + - cause: DEFAULT_USER_EMAIL env variable is set to a different email than the root admin in the db + - verify_key_exists will return True + - result: no new admin is created, as the server already has a root admin + """ + user_stash = server.services.user.stash + + email_exists = user_stash.email_exists(email=email).unwrap() + if email_exists: + logger.debug("Admin not created, a user with this email already exists") + return None + + verify_key_exists = user_stash.verify_key_exists(server.verify_key).unwrap() + if verify_key_exists: + logger.debug("Admin not created, this server already has a root admin") + return None + + create_user = UserCreate( + name=name, + email=email, + password=password, + password_verify=password, + role=ServiceRole.ADMIN, + ) + + # New User Initialization + # 🟡 TODO: change later but for now this gives the main user super user automatically + user = create_user.to(User) + user.signing_key = server.signing_key + user.verify_key = server.verify_key + + new_user = user_stash.set( + credentials=server.verify_key, + obj=user, + ignore_duplicates=False, + ).unwrap() + + logger.debug(f"Created admin {new_user.email}") + + return new_user diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index 83a30bb670b..f87e9818027 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -41,9 +41,10 @@ def add( except Exception as e: raise SyftException(public_message=f"Failed to create registry. 
{e}") - self.stash.set(context.credentials, registry).unwrap() + stored_registry = self.stash.set(context.credentials, registry).unwrap() return SyftSuccess( - message=f"Image Registry ID: {registry.id} created successfully" + message=f"Image Registry ID: {registry.id} created successfully", + value=stored_registry, ) @service_method( diff --git a/packages/syft/src/syft/service/worker/utils.py b/packages/syft/src/syft/service/worker/utils.py index 01d440879bd..d5967364593 100644 --- a/packages/syft/src/syft/service/worker/utils.py +++ b/packages/syft/src/syft/service/worker/utils.py @@ -251,6 +251,7 @@ def run_workers_in_threads( error = None worker_name = f"{pool_name}-{worker_count}" worker = SyftWorker( + id=UID.with_seed(worker_name), name=worker_name, status=WorkerStatus.RUNNING, worker_pool_name=pool_name, diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index a5f05f94dac..dd0795cf8df 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -89,6 +89,7 @@ def build( tag: str, registry_uid: UID | None = None, pull_image: bool = True, + force_build: bool = False, ) -> SyftSuccess: registry: SyftImageRegistry | None = None @@ -122,6 +123,7 @@ def build( and worker_image.image_identifier and worker_image.image_identifier.full_name_with_tag == image_identifier.full_name_with_tag + and not force_build ): raise SyftException( public_message=f"Image ID: {image_uid} is already built" @@ -192,18 +194,7 @@ def get_all(self, context: AuthedServiceContext) -> DictTuple[str, SyftWorkerIma One image one docker file for now """ images = self.stash.get_all(credentials=context.credentials).unwrap() - - res = {} - # if image is built, index it by full_name_with_tag - for im in images: - if im.is_built and im.image_identifier is not None: - res[im.image_identifier.full_name_with_tag] = im - # and then index all images by id - # TODO: jupyter repr needs to be updated to show unique values - # (even if multiple keys point to same value) - res.update({im.id.to_string(): im for im in images if not im.is_built}) - - return DictTuple(res) + return DictTuple({image.id.to_string(): image for image in images}) @service_method( path="worker_image.remove", diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 55b103ba369..3db794a2834 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -585,6 +585,32 @@ def delete( uid = worker_pool.id + self.purge_workers(context=context, pool_id=pool_id, pool_name=pool_name) + + self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap( + public_message=f"Failed to delete WorkerPool: {worker_pool.name} from stash" + ) + + return SyftSuccess(message=f"Successfully deleted worker pool with id {uid}") + + @service_method( + path="worker_pool.purge_workers", + name="purge_workers", + roles=DATA_OWNER_ROLE_LEVEL, + unwrap_on_success=False, + ) + def purge_workers( + self, + context: AuthedServiceContext, + pool_id: UID | None = None, + pool_name: str | None = None, + ) -> SyftSuccess: + worker_pool = self._get_worker_pool( + context, pool_id=pool_id, pool_name=pool_name + ).unwrap(public_message=f"Failed to get WorkerPool: {pool_id or pool_name}") + + uid = worker_pool.id + # relative from ..queue.queue_stash import 
Status @@ -614,9 +640,10 @@ def delete( if IN_KUBERNETES: # Scale the workers to zero - self.scale(context=context, number=0, pool_id=uid) runner = KubernetesRunner() - runner.delete_pool(pool_name=worker_pool.name) + if runner.exists(worker_pool.name): + self.scale(context=context, number=0, pool_id=uid) + runner.delete_pool(pool_name=worker_pool.name) else: workers = ( worker.resolve_with_context(context=context).unwrap() @@ -632,11 +659,19 @@ def delete( context=context, uid=id_, force=True ) - self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap( - public_message=f"Failed to delete WorkerPool: {worker_pool.name} from stash" + worker_pool.max_count = 0 + worker_pool.worker_list = [] + self.stash.update( + credentials=context.credentials, + obj=worker_pool, + ).unwrap( + public_message=( + f"Pool {worker_pool.name} was purged, " + f"but failed to update the stash" + ) ) - return SyftSuccess(message=f"Successfully deleted worker pool with id {uid}") + return SyftSuccess(message=f"Successfully Purged worker pool with id {uid}") @as_result(StashException, SyftException) def _get_worker_pool( diff --git a/packages/syft/src/syft/util/test_helpers/checkpoint.py b/packages/syft/src/syft/util/test_helpers/checkpoint.py new file mode 100644 index 00000000000..99b0a3b1e52 --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/checkpoint.py @@ -0,0 +1,192 @@ +# stdlib +import datetime +import os +from pathlib import Path + +# syft absolute +from syft import SyftError +from syft import SyftException +from syft.client.client import SyftClient +from syft.service.user.user_roles import ServiceRole +from syft.util.util import get_root_data_path + +# relative +from ...server.env import get_default_root_email +from ...server.env import get_default_root_password +from .worker_helpers import build_and_push_image + +CHECKPOINT_ROOT = "checkpoints" +CHECKPOINT_DIR_PREFIX = "chkpt" + + +def root_checkpoint_path() -> Path: + return get_root_data_path() / CHECKPOINT_ROOT + + +def get_checkpoint_parent_dir(server_uid: str, chkpt_name: str) -> Path: + return root_checkpoint_path() / chkpt_name / server_uid + + +def create_checkpoint_dir(server_uid: str, chkpt_name: str) -> Path: + """Create a checkpoint directory by chkpt_name and server_uid.""" + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + checkpoint_dir = f"{CHECKPOINT_DIR_PREFIX}_{timestamp}" + checkpoint_parent_dir = get_checkpoint_parent_dir( + server_uid=server_uid, chkpt_name=chkpt_name + ) + checkpoint_full_path = checkpoint_parent_dir / checkpoint_dir + + # Format of Checkpoint Directory: + # /checkpoints/chkpt_name//chkpt_ + + checkpoint_full_path.mkdir(parents=True, exist_ok=True) + return checkpoint_full_path + + +def is_admin(client: SyftClient) -> bool: + return client._SyftClient__user_role == ServiceRole.ADMIN + + +def create_checkpoint( + name: str, # Name of the checkpoint + client: SyftClient, + root_email: str | None = None, + root_pwd: str | None = None, +) -> None: + """Save a checkpoint for the database.""" + + if root_email is None: + root_email = get_default_root_email() + + if root_pwd is None: + root_pwd = get_default_root_password() + + root_client = ( + client + if is_admin(client) + else client.login(email=root_email, password=root_pwd) + ) + migration_data = root_client.get_migration_data(include_blobs=True) + + if isinstance(migration_data, SyftError): + raise SyftException(message=migration_data.message) + + checkpoint_dir = create_checkpoint_dir( + server_uid=client.id.to_string(), 
chkpt_name=name
+    )
+    migration_data.save(
+        path=checkpoint_dir / "migration.blob",
+        yaml_path=checkpoint_dir / "migration.yaml",
+    )
+    print(f"Checkpoint saved at: \n {checkpoint_dir}")
+
+
+def last_checkpoint_path_for(server_uid: str, chkpt_name: str) -> Path | None:
+    """Return the directory of the latest checkpoint for the given name."""
+
+    checkpoint_parent_dir = get_checkpoint_parent_dir(
+        server_uid=server_uid, chkpt_name=chkpt_name
+    )
+
+    checkpoint_dirs = [
+        d
+        for d in checkpoint_parent_dir.glob(f"{CHECKPOINT_DIR_PREFIX}_*")
+        if d.is_dir()
+    ]
+    checkpoints_dirs_with_blob_entry = [
+        d for d in checkpoint_dirs if any(d.glob("*.blob"))
+    ]
+
+    if checkpoints_dirs_with_blob_entry:
+        print(f"Loading from the last checkpoint for: {chkpt_name}")
+        return max(checkpoints_dirs_with_blob_entry, key=lambda d: d.stat().st_mtime)
+
+    return None
+
+
+def get_registry_credentials() -> tuple[str, str]:
+    return os.environ.get("REGISTRY_USERNAME", ""), os.environ.get(
+        "REGISTRY_PASSWORD", ""
+    )
+
+
+def load_from_checkpoint(
+    client: SyftClient,
+    name: str,
+    root_email: str | None = None,
+    root_password: str | None = None,
+    registry_username: str | None = None,
+    registry_password: str | None = None,
+) -> None:
+    """Load the last saved checkpoint for the given checkpoint name."""
+
+    root_email = "info@openmined.org" if root_email is None else root_email
+    root_password = "changethis" if root_password is None else root_password
+
+    root_client = (
+        client
+        if is_admin(client)
+        else client.login(email=root_email, password=root_password)
+    )
+    latest_checkpoint_dir = last_checkpoint_path_for(
+        server_uid=client.id.to_string(), chkpt_name=name
+    )
+
+    if latest_checkpoint_dir is None:
+        print(f"No last checkpoint found for: {name}")
+        return
+
+    print(f"Loading from checkpoint: {latest_checkpoint_dir}")
+    result = root_client.load_migration_data(
+        path_or_data=latest_checkpoint_dir / "migration.blob",
+        include_worker_pools=True,
+        with_reset_db=True,
+    )
+
+    if isinstance(result, SyftError):
+        raise SyftException(message=result.message)
+
+    print("Successfully loaded data from checkpoint.")
+
+    # Step 1: Build and push the worker images
+
+    print("Recreating worker images from checkpoint.")
+    worker_image_list = (
+        [] if root_client.images.get_all() is None else root_client.images.get_all()
+    )
+    for worker_image in worker_image_list:
+        if worker_image.is_prebuilt:
+            continue
+
+        registry = worker_image.image_identifier.registry
+
+        build_and_push_image(
+            root_client,
+            worker_image,
+            registry_uid=registry.id if registry else None,
+            tag=worker_image.image_identifier.repo_with_tag,
+            reg_username=registry_username,
+            reg_password=registry_password,
+            force_build=True,
+        )
+
+    print("Successfully built worker images from checkpoint.")
+
+    # Step 2: Recreate the worker pools
+    print("Recreating worker pools from checkpoint.")
+    worker_pool_list = (
+        [] if root_client.worker_pools is None else root_client.worker_pools
+    )
+    for worker_pool in worker_pool_list:
+        previous_worker_cnt = worker_pool.max_count
+        purge_res = root_client.worker_pools.purge_workers(pool_id=worker_pool.id)
+        print(purge_res)
+        add_res = root_client.worker_pools.add_workers(
+            number=previous_worker_cnt,
+            pool_id=worker_pool.id,
+            registry_username=registry_username,
+            registry_password=registry_password,
+        )
+        print(add_res)
+
+    print("Successfully loaded worker pool data from checkpoint.")
diff --git a/packages/syft/src/syft/util/test_helpers/worker_helpers.py
b/packages/syft/src/syft/util/test_helpers/worker_helpers.py index 3c2667fecc8..7a9d2f18842 100644 --- a/packages/syft/src/syft/util/test_helpers/worker_helpers.py +++ b/packages/syft/src/syft/util/test_helpers/worker_helpers.py @@ -1,6 +1,12 @@ # syft absolute import syft as sy +# relative +from ...client.client import SyftClient +from ...service.response import SyftSuccess +from ...service.worker.worker_image import SyftWorkerImage +from ...types.uid import UID + def build_and_launch_worker_pool_from_docker_str( environment: str, @@ -84,3 +90,39 @@ def launch_worker_pool_from_docker_tag_and_registry( print(result) return launch_result + + +def prune_worker_pool_and_images(client: SyftClient) -> None: + for pool in client.worker_pools.get_all(): + client.worker_pools.delete(pool.id) + + for image in client.images.get_all(): + client.images.remove(image.id) + + +def build_and_push_image( + client: SyftClient, + image: SyftWorkerImage, + tag: str, + registry_uid: UID | None = None, + reg_username: str | None = None, + reg_password: str | None = None, + force_build: bool = False, +) -> None: + """Build and push the image to the given registry.""" + if image.is_prebuilt: + return + + build_result = client.api.services.worker_image.build( + image_uid=image.id, registry_uid=registry_uid, tag=tag, force_build=force_build + ) + print(build_result.message) + + if isinstance(build_result, SyftSuccess): + push_result = client.api.services.worker_image.push( + image.id, + username=reg_username, + password=reg_password, + ) + assert isinstance(push_result, SyftSuccess) # nosec: B101 + print(push_result.message) diff --git a/packages/syft/tests/syft/worker_pool/worker_pool_service_test.py b/packages/syft/tests/syft/worker_pool/worker_pool_service_test.py index f3306ab4fd4..a14cdef3f8b 100644 --- a/packages/syft/tests/syft/worker_pool/worker_pool_service_test.py +++ b/packages/syft/tests/syft/worker_pool/worker_pool_service_test.py @@ -30,7 +30,7 @@ 2, # total number of images. # 2 since we pull a pre-built image (1) as the base image to build a custom image (2) ), - (None, PrebuiltWorkerConfig(tag=PREBUILT_IMAGE_TAG), 1), + (None, PrebuiltWorkerConfig(tag=PREBUILT_IMAGE_TAG), 2), ] WORKER_CONFIG_TEST_CASES = [