Skip to content

Commit

Permalink
Refactor provision for general use - Part 1 (NVIDIA#3092)
Browse files Browse the repository at this point in the history
* refactor provision for general use

* reformat

* address pr comments

* fix test case

* reorg file structure

* address pr comments
  • Loading branch information
yanchengnv authored Dec 5, 2024
1 parent fab6347 commit e594dcb
Show file tree
Hide file tree
Showing 24 changed files with 3,134 additions and 693 deletions.
4 changes: 3 additions & 1 deletion nvflare/apis/utils/format_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def name_check(name: str, entity_type: str):
if re.match(regex_pattern, name):
return False, "name={} passed on regex_pattern={} check".format(name, regex_pattern)
else:
return True, "name={} is ill-formatted based on regex_pattern={}".format(name, regex_pattern)
return True, "name={} is ill-formatted for entity_type={} based on regex_pattern={}".format(
name, entity_type, regex_pattern
)


def validate_class_methods_args(cls):
Expand Down
144 changes: 144 additions & 0 deletions nvflare/lighter/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class WorkDir:
WORKSPACE = "workspace"
WIP = "wip_dir"
STATE = "state_dir"
RESOURCES = "resources_dir"
CURRENT_PROD_DIR = "current_prod_dir"


class ParticipantType:
SERVER = "server"
CLIENT = "client"
ADMIN = "admin"
OVERSEER = "overseer"


class PropKey:
API_VERSION = "api_version"
NAME = "name"
DESCRIPTION = "description"
ROLE = "role"
HOST_NAMES = "host_names"
CONNECT_TO = "connect_to"
LISTENING_HOST = "listening_host"
DEFAULT_HOST = "default_host"
PROTOCOL = "protocol"
API_ROOT = "api_root"
PORT = "port"
OVERSEER_END_POINT = "overseer_end_point"
ADMIN_PORT = "admin_port"
FED_LEARN_PORT = "fed_learn_port"


class CtxKey(WorkDir, PropKey):
PROJECT = "__project__"
TEMPLATE = "__template__"
PROVISION_MODE = "__provision_model__"
LAST_PROD_STAGE = "last_prod_stage"
TEMPLATE_FILES = "template_files"
SERVER_NAME = "server_name"
ROOT_CERT = "root_cert"
ROOT_PRI_KEY = "root_pri_key"


class ProvisionMode:
POC = "poc"
NORMAL = "normal"


class AdminRole:
PROJECT_ADMIN = "project_admin"
ORG_ADMIN = "org_admin"
LEAD = "lead"
MEMBER = "member"


class OverseerRole:
SERVER = "server"
CLIENT = "client"
ADMIN = "admin"


class TemplateSectionKey:
START_SERVER_SH = "start_svr_sh"
START_CLIENT_SH = "start_cln_sh"
DOCKER_SERVER_SH = "docker_svr_sh"
DOCKER_CLIENT_SH = "docker_cln_sh"
DOCKER_ADMIN_SH = "docker_adm_sh"
GUNICORN_CONF_PY = "gunicorn_conf_py"
START_OVERSEER_SH = "start_ovsr_sh"
FED_SERVER = "fed_server"
FED_CLIENT = "fed_client"
SUB_START_SH = "sub_start_sh"
STOP_FL_SH = "stop_fl_sh"
LOG_CONFIG = "log_config"
LOCAL_SERVER_RESOURCES = "local_server_resources"
LOCAL_CLIENT_RESOURCES = "local_client_resources"
SAMPLE_PRIVACY = "sample_privacy"
DEFAULT_AUTHZ = "default_authz"
SERVER_README = "readme_fs"
CLIENT_README = "readme_fc"
ADMIN_README = "readme_am"
FL_ADMIN_SH = "fl_admin_sh"
FED_ADMIN = "fed_admin"
COMPOSE_YAML = "compose_yaml"
DOCKERFILE = "dockerfile"
HELM_CHART_CHART = "helm_chart_chart"
HELM_CHART_VALUES = "helm_chart_values"
HELM_CHART_SERVICE_OVERSEER = "helm_chart_service_overseer"
HELM_CHART_SERVICE_SERVER = "helm_chart_service_server"
HELM_CHART_DEPLOYMENT_OVERSEER = "helm_chart_deployment_overseer"
HELM_CHART_DEPLOYMENT_SERVER = "helm_chart_deployment_server"


class ProvFileName:
START_SH = "start.sh"
SUB_START_SH = "sub_start.sh"
PRIVILEGE_YML = "privilege.yml"
DOCKER_SH = "docker.sh"
GUNICORN_CONF_PY = "gunicorn.conf.py"
FED_SERVER_JSON = "fed_server.json"
FED_CLIENT_JSON = "fed_client.json"
STOP_FL_SH = "stop_fl.sh"
LOG_CONFIG_DEFAULT = "log.config.default"
RESOURCES_JSON_DEFAULT = "resources.json.default"
PRIVACY_JSON_SAMPLE = "privacy.json.sample"
AUTHORIZATION_JSON_DEFAULT = "authorization.json.default"
README_TXT = "readme.txt"
FED_ADMIN_JSON = "fed_admin.json"
FL_ADMIN_SH = "fl_admin.sh"
SIGNATURE_JSON = "signature.json"
COMPOSE_YAML = "compose.yaml"
ENV = ".env"
COMPOSE_BUILD_DIR = "nvflare_compose"
DOCKERFILE = "Dockerfile"
REQUIREMENTS_TXT = "requirements.txt"
SERVER_CONTEXT_TENSEAL = "server_context.tenseal"
CLIENT_CONTEXT_TENSEAL = "client_context.tenseal"
HELM_CHART_DIR = "nvflare_hc"
DEPLOYMENT_OVERSEER_YAML = "deployment_overseer.yaml"
SERVICE_OVERSEER_YAML = "service_overseer.yaml"
CHART_YAML = "Chart.yaml"
VALUES_YAML = "values.yaml"
HELM_CHART_TEMPLATES_DIR = "templates"


class CertFileBasename:
CLIENT = "client"
SERVER = "server"
OVERSEER = "overseer"
110 changes: 110 additions & 0 deletions nvflare/lighter/ctx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

import yaml

from nvflare.lighter import utils

from .constants import CtxKey, PropKey, ProvisionMode
from .entity import Entity, Project


class ProvisionContext(dict):
def __init__(self, workspace_root_dir: str, project: Project):
super().__init__()
self[CtxKey.WORKSPACE] = workspace_root_dir

wip_dir = os.path.join(workspace_root_dir, "wip")
state_dir = os.path.join(workspace_root_dir, "state")
resources_dir = os.path.join(workspace_root_dir, "resources")
self.update({CtxKey.WIP: wip_dir, CtxKey.STATE: state_dir, CtxKey.RESOURCES: resources_dir})
dirs = [workspace_root_dir, resources_dir, wip_dir, state_dir]
utils.make_dirs(dirs)

# set commonly used data into ctx
self[CtxKey.PROJECT] = project

server = project.get_server()
admin_port = server.get_prop(PropKey.ADMIN_PORT, 8003)
self[CtxKey.ADMIN_PORT] = admin_port
fed_learn_port = server.get_prop(PropKey.FED_LEARN_PORT, 8002)
self[CtxKey.FED_LEARN_PORT] = fed_learn_port
self[CtxKey.SERVER_NAME] = server.name

def get_project(self):
return self.get(CtxKey.PROJECT)

def set_template(self, template: dict):
self[CtxKey.TEMPLATE] = template

def get_template(self):
return self.get(CtxKey.TEMPLATE)

def get_template_section(self, section_key: str):
template = self.get_template()
if not template:
raise RuntimeError("template is not available")

section = template.get(section_key)
if not section:
raise RuntimeError(f"missing section {section} in template")

return section

def set_provision_mode(self, mode: str):
valid_modes = [ProvisionMode.POC, ProvisionMode.NORMAL]
if mode not in valid_modes:
raise ValueError(f"invalid provision mode {mode}: must be one of {valid_modes}")
self[CtxKey.PROVISION_MODE] = mode

def get_provision_mode(self):
return self.get(CtxKey.PROVISION_MODE)

def get_wip_dir(self):
return self.get(CtxKey.WIP)

def get_ws_dir(self, entity: Entity):
return os.path.join(self.get_wip_dir(), entity.name)

def get_kit_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "startup")

def get_transfer_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "transfer")

def get_local_dir(self, entity: Entity):
return os.path.join(self.get_ws_dir(entity), "local")

def get_state_dir(self):
return self.get(CtxKey.STATE)

def get_resources_dir(self):
return self.get(CtxKey.RESOURCES)

def get_workspace(self):
return self.get(CtxKey.WORKSPACE)

def yaml_load_template_section(self, section_key: str):
return yaml.safe_load(self.get_template_section(section_key))

def json_load_template_section(self, section_key: str):
return json.loads(self.get_template_section(section_key))

def build_from_template(self, dest_dir: str, temp_section: str, file_name, replacement=None, mode="t", exe=False):
section = self.get_template_section(temp_section)
if replacement:
section = utils.sh_replace(section, replacement)
utils.write(os.path.join(dest_dir, file_name), section, mode, exe=exe)
1 change: 0 additions & 1 deletion nvflare/lighter/dummy_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ builders:
- master_template.yml
- aws_template.yml
- azure_template.yml
- path: nvflare.lighter.impl.template.TemplateBuilder
- path: nvflare.lighter.impl.static_file.StaticFileBuilder
args:
# config_folder can be set to inform NVIDIA FLARE where to get configuration
Expand Down
Loading

0 comments on commit e594dcb

Please sign in to comment.