Skip to content

Commit

Permalink
Merge branch 'main' into update-hf-template
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Sep 11, 2024
2 parents 22c7be3 + 28a41aa commit a1d0c25
Showing 1 changed file with 23 additions and 36 deletions.
59 changes: 23 additions & 36 deletions src/py/flwr/cli/new/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,36 +136,23 @@ def new(
username = prompt_text("Please provide your Flower username")

if framework is not None:
framework_str_upper = str(framework.value)
framework_str = str(framework.value)
else:
framework_value = prompt_options(
framework_str = prompt_options(
"Please select ML framework by typing in the number",
[mlf.value for mlf in MlFramework],
)
selected_value = [
name
for name, value in vars(MlFramework).items()
if value == framework_value
]
framework_str_upper = selected_value[0]

framework_str = framework_str_upper.lower()

llm_challenge_str = None
if framework_str == "flowertune":
if framework_str == MlFramework.FLOWERTUNE:
llm_challenge_value = prompt_options(
"Please select LLM challenge by typing in the number",
sorted([challenge.value for challenge in LlmChallengeName]),
)
selected_value = [
name
for name, value in vars(LlmChallengeName).items()
if value == llm_challenge_value
]
llm_challenge_str = selected_value[0]
llm_challenge_str = llm_challenge_str.lower()
llm_challenge_str = llm_challenge_value.lower()

is_baseline_project = framework_str == "baseline"
if framework_str == MlFramework.BASELINE:
framework_str = "baseline"

print(
typer.style(
Expand All @@ -176,19 +163,21 @@ def new(
)

context = {
"framework_str": framework_str_upper,
"framework_str": framework_str,
"import_name": import_name.replace("-", "_"),
"package_name": package_name,
"project_name": app_name,
"username": username,
}

template_name = framework_str.lower()

# List of files to render
if llm_challenge_str:
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
"README.md": {"template": f"app/README.{framework_str}.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
"README.md": {"template": f"app/README.{template_name}.md.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server_app.py": {
"template": "app/code/flwr_tune/server_app.py.tpl"
Expand Down Expand Up @@ -235,44 +224,42 @@ def new(
files = {
".gitignore": {"template": "app/.gitignore.tpl"},
"README.md": {"template": "app/README.md.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{framework_str}.toml.tpl"},
"pyproject.toml": {"template": f"app/pyproject.{template_name}.toml.tpl"},
f"{import_name}/__init__.py": {"template": "app/code/__init__.py.tpl"},
f"{import_name}/server_app.py": {
"template": f"app/code/server.{framework_str}.py.tpl"
"template": f"app/code/server.{template_name}.py.tpl"
},
f"{import_name}/client_app.py": {
"template": f"app/code/client.{framework_str}.py.tpl"
"template": f"app/code/client.{template_name}.py.tpl"
},
}

# Depending on the framework, generate task.py file
frameworks_with_tasks = [
MlFramework.PYTORCH.value.lower(),
MlFramework.JAX.value.lower(),
MlFramework.HUGGINGFACE.value.lower(),
MlFramework.MLX.value.lower(),
MlFramework.TENSORFLOW.value.lower(),
MlFramework.PYTORCH.value,
MlFramework.JAX.value,
MlFramework.HUGGINGFACE.value,
MlFramework.MLX.value,
MlFramework.TENSORFLOW.value,
]
if framework_str in frameworks_with_tasks:
files[f"{import_name}/task.py"] = {
"template": f"app/code/task.{framework_str}.py.tpl"
"template": f"app/code/task.{template_name}.py.tpl"
}

if is_baseline_project:
if framework_str == "baseline":
# Include additional files for baseline template
for file_name in ["model", "dataset", "strategy", "utils", "__init__"]:
files[f"{import_name}/{file_name}.py"] = {
"template": f"app/code/{file_name}.{framework_str}.py.tpl"
"template": f"app/code/{file_name}.{template_name}.py.tpl"
}

# Replace README.md
files["README.md"]["template"] = f"app/README.{framework_str}.md.tpl"
files["README.md"]["template"] = f"app/README.{template_name}.md.tpl"

# Add LICENSE
files["LICENSE"] = {"template": "app/LICENSE.tpl"}

context["framework_str"] = "baseline"

for file_path, value in files.items():
render_and_create(
file_path=project_dir / file_path,
Expand Down

0 comments on commit a1d0c25

Please sign in to comment.