Skip to content

Commit

Permalink
Address a few PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacYangSLA committed Jan 19, 2024
1 parent 73d45f9 commit 3907aa7
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 72 deletions.
74 changes: 20 additions & 54 deletions nvflare/dashboard/application/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

lighter_folder = os.path.dirname(utils.__file__)
template = utils.load_yaml(os.path.join(lighter_folder, "impl", "master_template.yml"))
for csp in ["aws", "azure"]:
supported_csps = ["aws", "azure"]
for csp in supported_csps:
csp_template_file = os.path.join(lighter_folder, "impl", f"{csp}_template.yml")
if os.path.exists(csp_template_file):
template.update(utils.load_yaml(csp_template_file))
Expand Down Expand Up @@ -130,43 +131,20 @@ def gen_server(key, first_server=True):
)
utils._write_pki(type="server", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
if not project.ha_mode:
utils._write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
tplt.get_start_sh("azure", "server"),
{
"type": "server",
"docker_network": "--network host",
"cln_uid": "",
"server_name": entity.name,
"ORG": "",
},
),
"t",
exe=True,
)
utils._write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
tplt.get_start_sh("aws"),
{
"type": "server",
"inbound_rule": "aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log",
"cln_uid": "",
"server_name": entity.name,
"ORG": "",
},
),
"t",
exe=True,
)
for csp in supported_csps:
utils._write(
os.path.join(dest_dir, get_csp_start_script_name(csp)),
tplt.get_start_sh(csp=csp, type="server", entity=entity),
"t",
exe=True,
)
signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))

# local folder creation
dest_dir = os.path.join(server_dir, "local")
os.mkdir(dest_dir)
utils._write_local(type="client", template=template)
utils._write_local(type="server", dest_dir=dest_dir, template=template)

# workspace folder file
utils._write(
Expand Down Expand Up @@ -232,31 +210,21 @@ def gen_client(key, id):
config=config,
)

utils._write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
tplt.get_start_sh("azure", "client"),
{"type": "client", "docker_network": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
exe=True,
)
utils._write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
tplt.get_start_sh("aws"),
{"type": "client", "inbound_rule": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
),
"t",
exe=True,
)
for csp in supported_csps:
utils._write(
os.path.join(dest_dir, get_csp_start_script_name(csp)),
tplt.get_start_sh(csp=csp, type="client", entity=entity),
"t",
exe=True,
)

signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))

# local folder creation
dest_dir = os.path.join(client_dir, "local")
os.mkdir(dest_dir)
utils._write_local(type="client", template=template, capacity=client.capacity.capacity)
utils._write_local(type="client", dest_dir=dest_dir, template=template, capacity=client.capacity.capacity)

# workspace folder file
utils._write(
Expand Down Expand Up @@ -312,9 +280,7 @@ def gen_user(key, id):
"t",
exe=True,
)
utils._write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
utils._write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
utils._write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))

Expand Down
69 changes: 52 additions & 17 deletions nvflare/lighter/tplt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.


from . import utils


class Template:
def __init__(self, template):
self.template = template
Expand All @@ -21,23 +24,48 @@ def __init__(self, template):
def get_cloud_script_header(self):
return self.template.get("cloud_script_header")

def get_start_sh(self, csp, type="server"):
if csp == "aws":
return self.get_cloud_script_header() + self.template.get("aws_start_sh")
elif csp == "azure":
if type == "server":
return (
self.get_cloud_script_header()
+ self.get_azure_start_svr_header_sh()
+ self.get_azure_start_common_sh()
)
elif type == "client":
return (
self.get_cloud_script_header()
+ self.get_azure_start_cln_header_sh()
+ self.get_azure_start_common_sh()
)
return ""
def get_azure_server_start_sh(self, entity):
tmp = self.get_cloud_script_header() + self.get_azure_start_svr_header_sh() + self.get_azure_start_common_sh()
script = utils.sh_replace(
tmp,
{
"type": "server",
"docker_network": "--network host",
"cln_uid": "",
"server_name": entity.name,
"ORG": "",
},
)
return script

def get_aws_server_start_sh(self, entity):
tmp = self.get_cloud_script_header() + self.template.get("aws_start_sh")
script = utils.sh_replace(
tmp,
{
"type": "server",
"inbound_rule": "aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log",
"cln_uid": "",
"server_name": entity.name,
"ORG": "",
},
)
return script

def get_azure_client_start_sh(self, entity):
tmp = self.get_cloud_script_header() + self.get_azure_start_cln_header_sh() + self.get_azure_start_common_sh()
script = utils.sh_replace(
tmp,
{"type": "client", "docker_network": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org},
)
return script

def get_aws_client_start_sh(self, entity):
tmp = self.get_cloud_script_header() + self.template.get("aws_start_sh")
script = utils.sh_replace(
tmp, {"type": "client", "inbound_rule": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org}
)
return script

def get_azure_start_svr_header_sh(self):
return self.template.get("azure_start_svr_header_sh")
Expand All @@ -56,3 +84,10 @@ def get_azure_svr_sh(self):

def get_azure_cln_sh(self):
return self.get_cloud_script_header() + self.get_azure_start_cln_header_sh() + self.get_azure_start_common_sh()

def get_start_sh(self, csp, type, entity):
try:
func = getattr(self, f"get_{csp}_{type}_start_sh")
return func(entity)
except AttributeError:
return ""
2 changes: 1 addition & 1 deletion nvflare/lighter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ def _write_local(type, dest_dir, template, capacity=""):
template["default_authz"],
"t",
)
resources = json.loads(template["local_client_resources"])
if type == "client":
resources = json.loads(template["local_client_resources"])
for component in resources["components"]:
if "nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager" == component["path"]:
component["args"] = json.loads(capacity)
Expand Down

0 comments on commit 3907aa7

Please sign in to comment.