Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AWS and Azure cloud scripts of server and client #2275

Merged
merged 1 commit into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 54 additions & 161 deletions nvflare/dashboard/application/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,17 @@

# Directory that contains the `utils` module; template files are read from its "impl" subfolder.
lighter_folder = os.path.dirname(utils.__file__)
# Base template mapping loaded from master_template.yml.
template = utils.load_yaml(os.path.join(lighter_folder, "impl", "master_template.yml"))


def get_csp_template(csp, participant, template):
    """Return the cloud start-script template for a provider and participant.

    Args:
        csp: Cloud service provider prefix used in the template key (e.g. "aws").
        participant: Participant tag used in the template key (e.g. "svr" or "cln").
        template: Mapping of template names to template text.

    Returns:
        The value stored under the key ``"{csp}_start_{participant}_sh"``.

    Raises:
        KeyError: If ``template`` has no entry for that key.
    """
    return template[f"{csp}_start_{participant}_sh"]
# Cloud service providers whose template files are merged into `template`.
supported_csps = ["aws", "azure"]
# Each provider template is optional: it is merged only when its file exists.
for csp in supported_csps:
    csp_template_file = os.path.join(lighter_folder, "impl", f"{csp}_template.yml")
    if os.path.exists(csp_template_file):
        template.update(utils.load_yaml(csp_template_file))


def get_csp_start_script_name(csp):
    """Return the file name of the start script for a cloud provider.

    Args:
        csp: Cloud service provider prefix (e.g. "aws" or "azure").

    Returns:
        The string ``"{csp}_start.sh"``.
    """
    return f"{csp}_start.sh"


def _write(file_full_path, content, mode, exe=False):
mode = mode + "w"
with open(file_full_path, mode) as f:
f.write(content)
if exe:
os.chmod(file_full_path, 0o755)


def gen_overseer(key):
project = Project.query.first()
entity = Entity(project.overseer)
Expand All @@ -54,21 +47,19 @@ def gen_overseer(key):
dest_dir = os.path.join(overseer_dir, "startup")
os.mkdir(overseer_dir)
os.mkdir(dest_dir)
_write(
utils._write(
os.path.join(dest_dir, "start.sh"),
template["start_ovsr_sh"],
"t",
exe=True,
)
_write(
utils._write(
os.path.join(dest_dir, "gunicorn.conf.py"),
utils.sh_replace(template["gunicorn_conf_py"], {"port": "8443"}),
"t",
exe=False,
)
_write(os.path.join(dest_dir, "overseer.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "overseer.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
utils._write_pki(type="overseer", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."]
subprocess.run(run_args, cwd=tmp_dir)
fileobj = io.BytesIO()
Expand Down Expand Up @@ -121,89 +112,42 @@ def gen_server(key, first_server=True):
"ha_mode": "true" if project.ha_mode else "false",
"docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
"org_name": "",
"type": "server",
"cln_uid": "",
}
tplt = tplt_utils.Template(template)
with tempfile.TemporaryDirectory() as tmp_dir:
server_dir = os.path.join(tmp_dir, entity.name)
dest_dir = os.path.join(server_dir, "startup")
os.mkdir(server_dir)
os.mkdir(dest_dir)
_write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t")
_write(
os.path.join(dest_dir, "docker.sh"),
utils.sh_replace(template["docker_svr_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "start.sh"),
utils.sh_replace(template["start_svr_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "sub_start.sh"),
utils.sh_replace(template["sub_start_svr_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "stop_fl.sh"),
template["stop_fl_sh"],
"t",
exe=True,
utils._write_common(
type="server",
dest_dir=dest_dir,
template=template,
tplt=tplt,
replacement_dict=replacement_dict,
config=config,
)
_write(os.path.join(dest_dir, "server.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "server.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
utils._write_pki(type="server", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
if not project.ha_mode:
_write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
tplt.get_cloud_script_header() + get_csp_template("azure", "svr", template),
{"server_name": entity.name, "ORG": ""},
),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
tplt.get_cloud_script_header() + get_csp_template("aws", "svr", template),
{"server_name": entity.name, "ORG": ""},
),
"t",
exe=True,
)
for csp in supported_csps:
utils._write(
os.path.join(dest_dir, get_csp_start_script_name(csp)),
tplt.get_start_sh(csp=csp, type="server", entity=entity),
"t",
exe=True,
)
signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))

# local folder creation
dest_dir = os.path.join(server_dir, "local")
os.mkdir(dest_dir)
_write(
os.path.join(dest_dir, "log.config.default"),
template["log_config"],
"t",
)
_write(
os.path.join(dest_dir, "resources.json.default"),
template["local_server_resources"],
"t",
)
_write(
os.path.join(dest_dir, "privacy.json.sample"),
template["sample_privacy"],
"t",
)
_write(
os.path.join(dest_dir, "authorization.json.default"),
template["default_authz"],
"t",
)
utils._write_local(type="server", dest_dir=dest_dir, template=template)

# workspace folder file
_write(
utils._write(
os.path.join(server_dir, "readme.txt"),
template["readme_fs"],
"t",
Expand Down Expand Up @@ -233,6 +177,8 @@ def gen_client(key, id):
"config_folder": "config",
"docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare",
"org_name": entity.org,
"type": "client",
"cln_uid": f"uid={entity.name}",
}
if project.ha_mode:
overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"}
Expand All @@ -254,85 +200,34 @@ def gen_client(key, id):
os.mkdir(client_dir)
os.mkdir(dest_dir)

_write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t")
_write(
os.path.join(dest_dir, "docker.sh"),
utils.sh_replace(template["docker_cln_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "start.sh"),
template["start_cln_sh"],
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "sub_start.sh"),
utils.sh_replace(template["sub_start_cln_sh"], replacement_dict),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, "stop_fl.sh"),
template["stop_fl_sh"],
"t",
exe=True,
)
_write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
_write(
os.path.join(dest_dir, get_csp_start_script_name("azure")),
utils.sh_replace(
tplt.get_cloud_script_header() + get_csp_template("azure", "cln", template),
{"SITE": entity.name, "ORG": entity.org},
),
"t",
exe=True,
)
_write(
os.path.join(dest_dir, get_csp_start_script_name("aws")),
utils.sh_replace(
tplt.get_cloud_script_header() + get_csp_template("aws", "cln", template),
{"SITE": entity.name, "ORG": entity.org},
),
"t",
exe=True,
utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
utils._write_common(
type="client",
dest_dir=dest_dir,
template=template,
tplt=tplt,
replacement_dict=replacement_dict,
config=config,
)

for csp in supported_csps:
utils._write(
os.path.join(dest_dir, get_csp_start_script_name(csp)),
tplt.get_start_sh(csp=csp, type="client", entity=entity),
"t",
exe=True,
)

signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))

# local folder creation
dest_dir = os.path.join(client_dir, "local")
os.mkdir(dest_dir)
_write(
os.path.join(dest_dir, "log.config.default"),
template["log_config"],
"t",
)
resources = json.loads(template["local_client_resources"])
for component in resources["components"]:
if "nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager" == component["path"]:
component["args"] = json.loads(client.capacity.capacity)
break
_write(
os.path.join(dest_dir, "resources.json.default"),
json.dumps(resources, indent=2),
"t",
)
_write(
os.path.join(dest_dir, "privacy.json.sample"),
template["sample_privacy"],
"t",
)
_write(
os.path.join(dest_dir, "authorization.json.default"),
template["default_authz"],
"t",
)
utils._write_local(type="client", dest_dir=dest_dir, template=template, capacity=client.capacity.capacity)

# workspace folder file
_write(
utils._write(
os.path.join(client_dir, "readme.txt"),
template["readme_fc"],
"t",
Expand Down Expand Up @@ -378,16 +273,14 @@ def gen_user(key, id):
os.mkdir(user_dir)
os.mkdir(dest_dir)

_write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t")
_write(
utils._write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t")
utils._write(
os.path.join(dest_dir, "fl_admin.sh"),
utils.sh_replace(template["fl_admin_sh"], replacement_dict),
"t",
exe=True,
)
_write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False)
_write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False)
_write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False)
utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert)
signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key))
json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt"))

Expand All @@ -396,12 +289,12 @@ def gen_user(key, id):
os.mkdir(dest_dir)

# workspace folder file
_write(
utils._write(
os.path.join(user_dir, "readme.txt"),
template["readme_am"],
"t",
)
_write(
utils._write(
os.path.join(user_dir, "system_info.ipynb"),
utils.sh_replace(template["adm_notebook"], replacement_dict),
"t",
Expand Down
3 changes: 1 addition & 2 deletions nvflare/dashboard/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import docker
import nvflare
from nvflare.apis.utils.format_check import name_check
from nvflare.dashboard.application.blob import _write
from nvflare.lighter import tplt_utils, utils

supported_csp = ("azure", "aws")
Expand Down Expand Up @@ -146,7 +145,7 @@ def cloud(args):
dsb_start = template[f"{csp}_start_dsb_sh"]
version = nvflare.__version__
replacement_dict = {"NVFLARE": f"nvflare=={version}", "START_OPT": f"-i {args.image}" if args.image else ""}
_write(
utils._write(
dest,
utils.sh_replace(tplt.get_cloud_script_header() + dsb_start, replacement_dict),
"t",
Expand Down
5 changes: 4 additions & 1 deletion nvflare/lighter/dummy_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ participants:
builders:
- path: nvflare.lighter.impl.workspace.WorkspaceBuilder
args:
template_file: master_template.yml
template_file:
- master_template.yml
- aws_template.yml
- azure_template.yml
- path: nvflare.lighter.impl.template.TemplateBuilder
- path: nvflare.lighter.impl.static_file.StaticFileBuilder
args:
Expand Down
5 changes: 4 additions & 1 deletion nvflare/lighter/ha_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ participants:
builders:
- path: nvflare.lighter.impl.workspace.WorkspaceBuilder
args:
template_file: master_template.yml
template_file:
- master_template.yml
- aws_template.yml
- azure_template.yml
- path: nvflare.lighter.impl.template.TemplateBuilder
- path: nvflare.lighter.impl.docker.DockerBuilder
args:
Expand Down
Loading