Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacYangSLA committed Jan 12, 2024
1 parent edf5165 commit 7eba08b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
22 changes: 6 additions & 16 deletions nvflare/dashboard/application/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def gen_server(key, first_server=True):
"docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
"org_name": "",
"type": "server",
"cln_uid": ""
"cln_uid": "",
}
tplt = tplt_utils.Template(template)
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down Expand Up @@ -155,13 +155,10 @@ def gen_server(key, first_server=True):
_write(os.path.join(dest_dir, "server.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
if not project.ha_mode:
azure_start_svr_header_sh = tplt.get_azure_start_svr_header_sh()
azure_start_common_sh = tplt.get_azure_start_common_sh()
script = tplt.get_cloud_script_header() + azure_start_svr_header_sh + azure_start_common_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
script,
tplt.get_start_sh("azure", "server"),
{
"type": "server",
"docker_network": "--network host",
Expand All @@ -173,12 +170,10 @@ def gen_server(key, first_server=True):
"t",
exe=True,
)
aws_start_svr_cln_sh = tplt.get_aws_start_sh()
script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
script,
tplt.get_start_sh("aws"),
{
"type": "server",
"inbound_rule": "aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log",
Expand Down Expand Up @@ -249,7 +244,7 @@ def gen_client(key, id):
"docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
"org_name": entity.org,
"type": "client",
"cln_uid": f"uid={entity.name}"
"cln_uid": f"uid={entity.name}",
}
if project.ha_mode:
overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"}
Expand Down Expand Up @@ -299,24 +294,19 @@ def gen_client(key, id):
_write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
azure_start_cln_header_sh = tplt.get_azure_start_cln_header_sh()
azure_start_common_sh = tplt.get_azure_start_common_sh()
script = tplt.get_cloud_script_header() + azure_start_cln_header_sh + azure_start_common_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
script,
tplt.get_start_sh("azure", "client"),
{"type": "client", "docker_network": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
exe=True,
)
aws_start_svr_cln_sh = tplt.get_aws_start_sh()
script = tplt.get_cloud_script_header() + aws_start_svr_cln_sh
_write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
script,
tplt.get_start_sh("aws"),
{"type": "client", "inbound_rule": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
Expand Down
4 changes: 2 additions & 2 deletions nvflare/lighter/impl/static_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _build_server(self, server, ctx):
"docker_image": self.docker_image,
"org_name": server.org,
"type": "server",
"cln_uid": ""
"cln_uid": "",
}
if self.docker_image:
self._write(
Expand Down Expand Up @@ -233,7 +233,7 @@ def _build_client(self, client, ctx):
"docker_image": self.docker_image,
"org_name": client.org,
"type": "client",
"cln_uid": f"uid={client.subject}"
"cln_uid": f"uid={client.subject}",
}
if self.overseer_agent:
overseer_agent = copy.deepcopy(self.overseer_agent)
Expand Down
13 changes: 11 additions & 2 deletions nvflare/lighter/tplt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
class Template:
def __init__(self, template):
self.template = template
self.supported_csps = ("azure", "aws")

def get_cloud_script_header(self):
return self.template.get("cloud_script_header")

def get_aws_start_sh(self):
return self.template.get("aws_start_sh")
def get_aws_start_sh(self, csp):
if csp in self.supported_csps:
return self.get_cloud_script_header() + self.template.get(f"{csp}_start_sh")
return ""

def get_azure_start_svr_header_sh(self):
return self.template.get("azure_start_svr_header_sh")
Expand All @@ -34,3 +37,9 @@ def get_azure_start_common_sh(self):

def get_sub_start_sh(self):
return self.template.get("sub_start_sh")

def get_azure_svr_sh(self):
return self.get_cloud_script_header() + self.get_azure_start_svr_header_sh() + self.get_azure_start_common_sh()

def get_azure_cln_sh(self):
return self.get_cloud_script_header() + self.get_azure_start_cln_header_sh() + self.get_azure_start_common_sh()

0 comments on commit 7eba08b

Please sign in to comment.