From cd8bfc5c47665cd3d6865c01e2cd122cf7fd4d3d Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Fri, 12 Jan 2024 08:50:31 -0800 Subject: [PATCH] Address PR comments --- nvflare/dashboard/application/blob.py | 18 ++++-------------- nvflare/lighter/tplt_utils.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/nvflare/dashboard/application/blob.py b/nvflare/dashboard/application/blob.py index bd00211567..a24477bf97 100644 --- a/nvflare/dashboard/application/blob.py +++ b/nvflare/dashboard/application/blob.py @@ -155,13 +155,10 @@ def gen_server(key, first_server=True): _write(os.path.join(dest_dir, "server.key"), cert_pair.ser_pri_key, "b", exe=False) _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False) if not project.ha_mode: - azure_start_svr_header_sh = tplt.get_azure_start_svr_header_sh() - azure_start_common_sh = tplt.get_azure_start_common_sh() - script = tplt.get_cloud_script_header() + azure_start_svr_header_sh + azure_start_common_sh _write( os.path.join(dest_dir, get_csp_start_script_name("azure")), utils.sh_replace( - script, + tplt.get_start_sh("azure", "server"), { "type": "server", "docker_network": "--network host", @@ -173,12 +170,10 @@ def gen_server(key, first_server=True): "t", exe=True, ) - aws_start_svr_cln_sh = tplt.get_aws_start_sh() - script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh _write( os.path.join(dest_dir, get_csp_start_script_name("aws")), utils.sh_replace( - script, + tplt.get_start_sh("aws"), { "type": "server", "inbound_rule": "aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log", @@ -299,24 +294,19 @@ def gen_client(key, id): _write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False) _write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False) _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False) - azure_start_cln_header_sh = tplt.get_azure_start_cln_header_sh() - azure_start_common_sh = tplt.get_azure_start_common_sh() - script = tplt.get_cloud_script_header() + azure_start_cln_header_sh + azure_start_common_sh _write( os.path.join(dest_dir, get_csp_start_script_name("azure")), utils.sh_replace( - script, + tplt.get_start_sh("azure", "client"), {"type": "client", "docker_network": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org}, ), "t", exe=True, ) - aws_start_svr_cln_sh = tplt.get_aws_start_sh() - script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh _write( os.path.join(dest_dir, get_csp_start_script_name("aws")), utils.sh_replace( - script, + tplt.get_start_sh("aws"), {"type": "client", "inbound_rule": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org}, ), "t", diff --git a/nvflare/lighter/tplt_utils.py b/nvflare/lighter/tplt_utils.py index a1d1a6fbac..830f833840 100644 --- a/nvflare/lighter/tplt_utils.py +++ b/nvflare/lighter/tplt_utils.py @@ -16,12 +16,15 @@ class Template: def __init__(self, template): self.template = template + self.supported_csps = ("azure", "aws") def get_cloud_script_header(self): return self.template.get("cloud_script_header") - def get_aws_start_sh(self): - return self.template.get("aws_start_sh") + def get_aws_start_sh(self, csp): + if csp in self.supported_csps: + return self.get_cloud_script_header() + self.template.get(f"{csp}_start_sh") + return "" def get_azure_start_svr_header_sh(self): return self.template.get("azure_start_svr_header_sh") @@ -34,3 +37,9 @@ def get_azure_start_common_sh(self): def get_sub_start_sh(self): return self.template.get("sub_start_sh") + + def get_azure_svr_sh(self): + return self.get_cloud_script_header() + self.get_azure_start_svr_header_sh() + self.get_azure_start_common_sh() + + def get_azure_cln_sh(self): + return self.get_cloud_script_header() + self.get_azure_start_cln_header_sh() + self.get_azure_start_common_sh()