Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacYangSLA committed Jan 12, 2024
1 parent edf5165 commit e6899e8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
18 changes: 4 additions & 14 deletions nvflare/dashboard/application/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,10 @@ def gen_server(key, first_server=True):
_write(os.path.join(dest_dir, "server.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
if not project.ha_mode:
azure_start_svr_header_sh = tplt.get_azure_start_svr_header_sh()
azure_start_common_sh = tplt.get_azure_start_common_sh()
script = tplt.get_cloud_script_header() + azure_start_svr_header_sh + azure_start_common_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
script,
tplt.get_start_sh("azure", "server"),
{
"type": "server",
"docker_network": "--network host",
Expand All @@ -173,12 +170,10 @@ def gen_server(key, first_server=True):
"t",
exe=True,
)
aws_start_svr_cln_sh = tplt.get_aws_start_sh()
script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
script,
tplt.get_start_sh("aws"),
{
"type": "server",
"inbound_rule": "aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log",
Expand Down Expand Up @@ -299,24 +294,19 @@ def gen_client(key, id):
_write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
azure_start_cln_header_sh = tplt.get_azure_start_cln_header_sh()
azure_start_common_sh = tplt.get_azure_start_common_sh()
script = tplt.get_cloud_script_header() + azure_start_cln_header_sh + azure_start_common_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
script,
tplt.get_start_sh("azure", "client"),
{"type": "client", "docker_network": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
exe=True,
)
aws_start_svr_cln_sh = tplt.get_aws_start_sh()
script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
script,
tplt.get_start_sh("aws"),
{"type": "client", "inbound_rule": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
Expand Down
13 changes: 11 additions & 2 deletions nvflare/lighter/tplt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
class Template:
def __init__(self, template):
self.template = template
self.supported_csps = ("azure", "aws")

def get_cloud_script_header(self):
return self.template.get("cloud_script_header")

def get_aws_start_sh(self):
return self.template.get("aws_start_sh")
def get_aws_start_sh(self, csp):
if csp in self.supported_csps:
return self.get_cloud_script_header() + self.template.get(f"{csp}_start_sh")
return ""

def get_azure_start_svr_header_sh(self):
return self.template.get("azure_start_svr_header_sh")
Expand All @@ -34,3 +37,9 @@ def get_azure_start_common_sh(self):

def get_sub_start_sh(self):
return self.template.get("sub_start_sh")

def get_azure_svr_sh(self):
return self.get_cloud_script_header() + self.get_azure_start_svr_header_sh() + self.get_azure_start_common_sh()

def get_azure_cln_sh(self):
return self.get_cloud_script_header() + self.get_azure_start_cln_header_sh() + self.get_azure_start_common_sh()

0 comments on commit e6899e8

Please sign in to comment.