Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacYangSLA committed Jan 17, 2024
1 parent edf5165 commit 26773de
Show file tree
Hide file tree
Showing 12 changed files with 984 additions and 981 deletions.
187 changes: 46 additions & 141 deletions nvflare/dashboard/application/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,16 @@

lighter_folder = os.path.dirname(utils.__file__)
template = utils.load_yaml(os.path.join(lighter_folder, "impl", "master_template.yml"))
for csp in ["aws", "azure"]:
csp_template_file = os.path.join(lighter_folder, "impl", f"{csp}_template.yml")
if os.path.exists(csp_template_file):
template.update(utils.load_yaml(csp_template_file))


def get_csp_start_script_name(csp):
return f"{csp}_start.sh"


def _write(file_full_path, content, mode, exe=False):
mode = mode + "w"
with open(file_full_path, mode) as f:
f.write(content)
if exe:
os.chmod(file_full_path, 0o755)


def gen_overseer(key):
project = Project.query.first()
entity = Entity(project.overseer)
Expand All @@ -50,21 +46,19 @@ def gen_overseer(key):
dest_dir = os.path.join(overseer_dir, "startup")
os.mkdir(overseer_dir)
os.mkdir(dest_dir)
_write(
utils._write(
os.path.join(dest_dir, "start.sh"),
template["start_ovsr_sh"],
"t",
exe=True,
)
_write(
utils._write(
os.path.join(dest_dir, "gunicorn.conf.py"),
utils.sh_replace(template["gunicorn_conf_py"], {"port": "8443"}),
"t",
exe=False,
)
_write(os.path.join(dest_dir, "overseer.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "overseer.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
utils._write_pki(type="overseer", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."]
subprocess.run(run_args, cwd=tmp_dir)
fileobj = io.BytesIO()
Expand Down Expand Up @@ -118,50 +112,28 @@ def gen_server(key, first_server=True):
"docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
"org_name": "",
"type": "server",
"cln_uid": ""
"cln_uid": "",
}
tplt = tplt_utils.Template(template)
with tempfile.TemporaryDirectory() as tmp_dir:
server_dir = os.path.join(tmp_dir, entity.name)
dest_dir = os.path.join(server_dir, "startup")
os.mkdir(server_dir)
os.mkdir(dest_dir)
_write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t")
_write(
os.path.join(dest_dir, "docker.sh"),
utils.sh_replace(template["docker_svr_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "start.sh"),
utils.sh_replace(template["start_svr_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "sub_start.sh"),
utils.sh_replace(tplt.get_sub_start_sh(), replacement_dict),
"t",
exe=True,
utils._write_common(
type="server",
dest_dir=dest_dir,
template=template,
tplt=tplt,
replacement_dict=replacement_dict,
config=config,
)
_write(
os.path.join(dest_dir, "stop_fl.sh"),
template["stop_fl_sh"],
"t",
exe=True,
)
_write(os.path.join(dest_dir, "server.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "server.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
utils._write_pki(type="server", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
if not project.ha_mode:
azure_start_svr_header_sh = tplt.get_azure_start_svr_header_sh()
azure_start_common_sh = tplt.get_azure_start_common_sh()
script = tplt.get_cloud_script_header() + azure_start_svr_header_sh + azure_start_common_sh
_write(
utils._write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
script,
tplt.get_start_sh("azure", "server"),
{
"type": "server",
"docker_network": "--network host",
Expand All @@ -173,12 +145,10 @@ def gen_server(key, first_server=True):
"t",
exe=True,
)
aws_start_svr_cln_sh = tplt.get_aws_start_sh()
script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh
_write(
utils._write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
script,
tplt.get_start_sh("aws"),
{
"type": "server",
"inbound_rule": "aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log",
Expand All @@ -196,29 +166,10 @@ def gen_server(key, first_server=True):
# local folder creation
dest_dir = os.path.join(server_dir, "local")
os.mkdir(dest_dir)
_write(
os.path.join(dest_dir, "log.config.default"),
template["log_config"],
"t",
)
_write(
os.path.join(dest_dir, "resources.json.default"),
template["local_server_resources"],
"t",
)
_write(
os.path.join(dest_dir, "privacy.json.sample"),
template["sample_privacy"],
"t",
)
_write(
os.path.join(dest_dir, "authorization.json.default"),
template["default_authz"],
"t",
)
utils._write_local(type="client", template=template)

# workspace folder file
_write(
utils._write(
os.path.join(server_dir, "readme.txt"),
template["readme_fs"],
"t",
Expand Down Expand Up @@ -249,7 +200,7 @@ def gen_client(key, id):
"docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
"org_name": entity.org,
"type": "client",
"cln_uid": f"uid={entity.name}"
"cln_uid": f"uid={entity.name}",
}
if project.ha_mode:
overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"}
Expand All @@ -271,52 +222,29 @@ def gen_client(key, id):
os.mkdir(client_dir)
os.mkdir(dest_dir)

_write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t")
_write(
os.path.join(dest_dir, "docker.sh"),
utils.sh_replace(template["docker_cln_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "start.sh"),
template["start_cln_sh"],
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "sub_start.sh"),
utils.sh_replace(tplt.get_sub_start_sh(), replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "stop_fl.sh"),
template["stop_fl_sh"],
"t",
exe=True,
utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
utils._write_common(
type="client",
dest_dir=dest_dir,
template=template,
tplt=tplt,
replacement_dict=replacement_dict,
config=config,
)
_write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
azure_start_cln_header_sh = tplt.get_azure_start_cln_header_sh()
azure_start_common_sh = tplt.get_azure_start_common_sh()
script = tplt.get_cloud_script_header() + azure_start_cln_header_sh + azure_start_common_sh
_write(

utils._write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
script,
tplt.get_start_sh("azure", "client"),
{"type": "client", "docker_network": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
exe=True,
)
aws_start_svr_cln_sh = tplt.get_aws_start_sh()
script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh
_write(
utils._write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
script,
tplt.get_start_sh("aws"),
{"type": "client", "inbound_rule": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
Expand All @@ -328,33 +256,10 @@ def gen_client(key, id):
# local folder creation
dest_dir = os.path.join(client_dir, "local")
os.mkdir(dest_dir)
_write(
os.path.join(dest_dir, "log.config.default"),
template["log_config"],
"t",
)
resources = json.loads(template["local_client_resources"])
for component in resources["components"]:
if "nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager" == component["path"]:
component["args"] = json.loads(client.capacity.capacity)
break
_write(
os.path.join(dest_dir, "resources.json.default"),
json.dumps(resources, indent=2),
"t",
)
_write(
os.path.join(dest_dir, "privacy.json.sample"),
template["sample_privacy"],
"t",
)
_write(
os.path.join(dest_dir, "authorization.json.default"),
template["default_authz"],
"t",
)
utils._write_local(type="client", template=template, capacity=client.capacity.capacity)

# workspace folder file
_write(
utils._write(
os.path.join(client_dir, "readme.txt"),
template["readme_fc"],
"t",
Expand Down Expand Up @@ -400,16 +305,16 @@ def gen_user(key, id):
os.mkdir(user_dir)
os.mkdir(dest_dir)

_write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t")
_write(
utils._write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t")
utils._write(
os.path.join(dest_dir, "fl_admin.sh"),
utils.sh_replace(template["fl_admin_sh"], replacement_dict),
"t",
exe=True,
)
_write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
utils._write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
utils._write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
utils._write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))

Expand All @@ -418,12 +323,12 @@ def gen_user(key, id):
os.mkdir(dest_dir)

# workspace folder file
_write(
utils._write(
os.path.join(user_dir, "readme.txt"),
template["readme_am"],
"t",
)
_write(
utils._write(
os.path.join(user_dir, "system_info.ipynb"),
utils.sh_replace(template["adm_notebook"], replacement_dict),
"t",
Expand Down
3 changes: 1 addition & 2 deletions nvflare/dashboard/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import docker
import nvflare
from nvflare.apis.utils.format_check import name_check
from nvflare.dashboard.application.blob import _write
from nvflare.lighter import tplt_utils, utils

supported_csp = ("azure", "aws")
Expand Down Expand Up @@ -146,7 +145,7 @@ def cloud(args):
dsb_start = template[f"{csp}_start_dsb_sh"]
version = nvflare.__version__
replacement_dict = {"NVFLARE": f"nvflare=={version}", "START_OPT": f"-i {args.image}" if args.image else ""}
_write(
utils._write(
dest,
utils.sh_replace(tplt.get_cloud_script_header() + dsb_start, replacement_dict),
"t",
Expand Down
5 changes: 4 additions & 1 deletion nvflare/lighter/dummy_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ participants:
builders:
- path: nvflare.lighter.impl.workspace.WorkspaceBuilder
args:
template_file: master_template.yml
template_file:
- master_template.yml
- aws_template.yml
- azure_template.yml
- path: nvflare.lighter.impl.template.TemplateBuilder
- path: nvflare.lighter.impl.static_file.StaticFileBuilder
args:
Expand Down
5 changes: 4 additions & 1 deletion nvflare/lighter/ha_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ participants:
builders:
- path: nvflare.lighter.impl.workspace.WorkspaceBuilder
args:
template_file: master_template.yml
template_file:
- master_template.yml
- aws_template.yml
- azure_template.yml
- path: nvflare.lighter.impl.template.TemplateBuilder
- path: nvflare.lighter.impl.docker.DockerBuilder
args:
Expand Down
Loading

0 comments on commit 26773de

Please sign in to comment.