From 28a41aa2573d04fc27d4c44ec1f3951963750176 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 11 Sep 2024 11:34:18 +0200 Subject: [PATCH] fix(framework) Adjust framework name in templates docstrings (#4127) --- src/py/flwr/cli/new/new.py | 59 +++++++++++++++----------------------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 520f683a47d..90e4970d592 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -136,36 +136,23 @@ def new( username = prompt_text("Please provide your Flower username") if framework is not None: - framework_str_upper = str(framework.value) + framework_str = str(framework.value) else: - framework_value = prompt_options( + framework_str = prompt_options( "Please select ML framework by typing in the number", [mlf.value for mlf in MlFramework], ) - selected_value = [ - name - for name, value in vars(MlFramework).items() - if value == framework_value - ] - framework_str_upper = selected_value[0] - - framework_str = framework_str_upper.lower() llm_challenge_str = None - if framework_str == "flowertune": + if framework_str == MlFramework.FLOWERTUNE: llm_challenge_value = prompt_options( "Please select LLM challenge by typing in the number", sorted([challenge.value for challenge in LlmChallengeName]), ) - selected_value = [ - name - for name, value in vars(LlmChallengeName).items() - if value == llm_challenge_value - ] - llm_challenge_str = selected_value[0] - llm_challenge_str = llm_challenge_str.lower() + llm_challenge_str = llm_challenge_value.lower() - is_baseline_project = framework_str == "baseline" + if framework_str == MlFramework.BASELINE: + framework_str = "baseline" print( typer.style( @@ -176,19 +163,21 @@ def new( ) context = { - "framework_str": framework_str_upper, + "framework_str": framework_str, "import_name": import_name.replace("-", "_"), "package_name": package_name, "project_name": app_name, "username": username, } + template_name = framework_str.lower() + # List of files to render if llm_challenge_str: files = { ".gitignore": {"template": "app/.gitignore.tpl"}, - "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"}, - "README.md": {"template": f"app/README.{framework_str}.md.tpl"}, + "pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"}, + "README.md": {"template": f"app/README.{template_name}.md.tpl"}, f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"}, f"{import_name}/server_app.py": { "template": "app/code/flwr_tune/server_app.py.tpl" @@ -235,44 +224,42 @@ def new( files = { ".gitignore": {"template": "app/.gitignore.tpl"}, "README.md": {"template": "app/README.md.tpl"}, - "pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"}, + "pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"}, f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"}, f"{import_name}/server_app.py": { - "template": f"app/code/server.{framework_str}.py.tpl" + "template": f"app/code/server.{template_name}.py.tpl" }, f"{import_name}/client_app.py": { - "template": f"app/code/client.{framework_str}.py.tpl" + "template": f"app/code/client.{template_name}.py.tpl" }, } # Depending on the framework, generate task.py file frameworks_with_tasks = [ - MlFramework.PYTORCH.value.lower(), - MlFramework.JAX.value.lower(), - MlFramework.HUGGINGFACE.value.lower(), - MlFramework.MLX.value.lower(), - MlFramework.TENSORFLOW.value.lower(), + MlFramework.PYTORCH.value, + MlFramework.JAX.value, + MlFramework.HUGGINGFACE.value, + MlFramework.MLX.value, + MlFramework.TENSORFLOW.value, ] if framework_str in frameworks_with_tasks: files[f"{import_name}/task.py"] = { - "template": f"app/code/task.{framework_str}.py.tpl" + "template": f"app/code/task.{template_name}.py.tpl" } - if is_baseline_project: + if framework_str == "baseline": # Include additional files for baseline template for file_name in ["model", "dataset", "strategy", "utils", "__init__"]: files[f"{import_name}/{file_name}.py"] = { - "template": f"app/code/{file_name}.{framework_str}.py.tpl" + "template": f"app/code/{file_name}.{template_name}.py.tpl" } # Replace README.md - files["README.md"]["template"] = f"app/README.{framework_str}.md.tpl" + files["README.md"]["template"] = f"app/README.{template_name}.md.tpl" # Add LICENSE files["LICENSE"] = {"template": "app/LICENSE.tpl"} - context["framework_str"] = "baseline" - for file_path, value in files.items(): render_and_create( file_path=project_dir / file_path,