Skip to content

Commit

Permalink
Enable patch and build for nvflight (NVIDIA#2574)
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacYangSLA authored and MinghuiChen43 committed May 9, 2024
1 parent 8f187eb commit 7d133c9
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 0 deletions.
13 changes: 13 additions & 0 deletions nvflight/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
79 changes: 79 additions & 0 deletions nvflight/build_wheel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os
import shutil
import subprocess

from prepare_setup import prepare_setup

import versioneer

versions = versioneer.get_versions()
if versions["error"]:
today = datetime.date.today().timetuple()
year = today[0] % 1000
month = today[1]
day = today[2]
version = f"2.3.9.dev{year:02d}{month:02d}{day:02d}"
else:
version = versions["version"]


def patch(setup_dir, patch_file):
file_dir_path = os.path.abspath(os.path.dirname(__file__))
cmd = ['git', 'apply', os.path.join(file_dir_path, patch_file)]
try:
subprocess.run(cmd, check=True, cwd=setup_dir)
except subprocess.CalledProcessError as e:
print(f"Error to patch prepared files {e}")
exit(1)

nvflight_setup_dir = "/tmp/nvflight_setup"
patch_file = "patch.diff"
# prepare
prepare_setup(nvflight_setup_dir)

patch(nvflight_setup_dir, patch_file)
# build wheel
dist_dir = os.path.join(nvflight_setup_dir, "dist")
if os.path.isdir(dist_dir):
shutil.rmtree(dist_dir)

env = os.environ.copy()
env['NVFL_VERSION'] = version

cmd_str = "python setup.py -v sdist bdist_wheel"
cmd = cmd_str.split(" ")
try:
subprocess.run(cmd, check=True, cwd=nvflight_setup_dir, env=env)
except subprocess.CalledProcessError as e:
print(f"Error: {e}")

results = []
for root, dirs, files in os.walk(dist_dir):
result = [os.path.join(root, f) for f in files if f.endswith(".whl")]
results.extend(result)

if not os.path.isdir("dist"):
os.makedirs("dist", exist_ok=True)

if len(results) == 1:
shutil.copy(results[0], os.path.join("dist", os.path.basename(results[0])))
else:
print(f"something is not right, wheel files = {results}")

print(f"Setup dir {nvflight_setup_dir}")
shutil.rmtree(nvflight_setup_dir)
27 changes: 27 additions & 0 deletions nvflight/patch.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
diff --git a/nvflare/client/__init__.py b/nvflare/client/__init__.py
index 8d668962..7bcb2978 100644
--- a/nvflare/client/__init__.py
+++ b/nvflare/client/__init__.py
@@ -15,22 +15,4 @@

# https://github.com/microsoft/pylance-release/issues/856

-from nvflare.apis.analytix import AnalyticsDataType as AnalyticsDataType
-from nvflare.app_common.abstract.fl_model import FLModel as FLModel
-from nvflare.app_common.abstract.fl_model import ParamsType as ParamsType
-
-from .api import get_config as get_config
-from .api import get_job_id as get_job_id
-from .api import get_site_name as get_site_name
-from .api import init as init
-from .api import is_evaluate as is_evaluate
-from .api import is_running as is_running
-from .api import is_submit_model as is_submit_model
-from .api import is_train as is_train
-from .api import log as log
-from .api import receive as receive
-from .api import send as send
-from .api import system_info as system_info
-from .decorator import evaluate as evaluate
-from .decorator import train as train
from .ipc.ipc_agent import IPCAgent
186 changes: 186 additions & 0 deletions nvflight/prepare_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import shutil

exclude_extensions = [".md", ".rst", ".pyc", "__pycache__"]

nvflight_packages = {
"nvflare": {
"include": ["_version.py"],
"exclude": ["*"]
},
"nvflare/apis": {
"include": ["__init__.py", "fl_constant.py"],
"exclude": ["*"]
},
"nvflare/app_common": {
"include": ["__init__.py"],
"exclude": ["*"]
},
"nvflare/app_common/decomposers": {
"include": ["__init__.py", "numpy_decomposers.py"],
"exclude": ["*"]
},
"nvflare/client": {
"include": ["__init__.py"],
"exclude": ["*"]
},
"nvflare/client/ipc": {
"include": ["__init__.py", "defs.py", "ipc_agent.py"],
"exclude": ["*"]
},
"nvflare/fuel": {
"include": ["__init__.py"],
"exclude": ["*"]
},
"nvflare/fuel/common": {
"include": ["*"],
"exclude": []
},
"nvflare/fuel/f3": {
"include": ["__init__.py",
"comm_error.py",
"connection.py",
"endpoint.py",
"mpm.py",
"stats_pool.py",
"comm_config.py",
"communicator.py",
"message.py",
"stream_cell.py"
],
"exclude": ["*"]
},
"nvflare/fuel/f3/cellnet": {
"include": ["*"],
"exclude": []
},
"nvflare/fuel/f3/drivers": {
"include": ["*"],
"exclude": ["grpc", "aio_grpc_driver.py", "aio_http_driver.py", "grpc_driver.py"]
},
"nvflare/fuel/f3/sfm": {
"include": ["*"],
"exclude": []
},
"nvflare/fuel/f3/streaming": {
"include": ["*"],
"exclude": []
},
"nvflare/fuel/hci": {
"include": ["__init__.py", "security.py"],
"exclude": ["*"]
},
"nvflare/fuel/utils": {
"include": ["*"],
"exclude": ["fobs"]
},
"nvflare/fuel/utils/fobs": {
"include": ["*"],
"exclude": []
},
"nvflare/fuel/utils/fobs/decomposers": {
"include": ["*"],
"exclude": []
},
"nvflare/security": {
"include": ["__init__.py", "logging.py"],
"exclude": ["*"]
}
}


def should_exclude(str_value):
return any(str_value.endswith(ext) for ext in exclude_extensions)


def package_selected_files(package_info: dict):
if not package_info:
return
all_items = "*"
results = {}

for p, package_rule in package_info.items():
include = package_rule["include"]
exclude = package_rule["exclude"]
paths = []
for include_item in include:
item_path = os.path.join(p, include_item)
if all_items != include_item:
if all_items in exclude:
# excluded everything except for included items
if os.path.isfile(item_path) and not should_exclude(item_path):
paths.append(item_path)
elif include_item not in exclude:
paths.append(item_path)
else:
if all_items in exclude:
# excluded everything except for included items
if os.path.isfile(item_path):
paths.append(item_path)
else:
# include everything in the package except excluded items
for root, dirs, files in os.walk(p):
if should_exclude(root) or os.path.basename(root) in exclude:
continue

for f in files:
if not should_exclude(f) and f not in exclude:
paths.append(os.path.join(root, f))
results[p] = paths
return results


def create_empty_file(file_path):
try:
with open(file_path, 'w'):
pass # This block is intentionally left empty
except Exception as e:
print(f"Error creating empty file: {e}")


def copy_files(package_paths: dict, target_dir: str):
for p, paths in package_paths.items():
for src_path in paths:
dst_path = os.path.join(target_dir, src_path)
os.makedirs(os.path.dirname(dst_path), exist_ok=True)
shutil.copy(src_path, dst_path)

for p in package_paths:
init_file_path = os.path.join(target_dir, p, "__init__.py")
if not os.path.isfile(init_file_path):
create_empty_file(init_file_path)


def prepare_setup(setup_dir: str):
if os.path.isdir(setup_dir):
shutil.rmtree(setup_dir)

os.makedirs(setup_dir, exist_ok=True)
nvflight_paths = package_selected_files(nvflight_packages)
copy_files(nvflight_paths, setup_dir)

src_files = [
"setup.cfg",
"README.md",
"LICENSE",
os.path.join("nvflight", "setup.py")
]

for src in src_files:
shutil.copy(src, os.path.join(setup_dir, os.path.basename(src)))

54 changes: 54 additions & 0 deletions nvflight/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

from setuptools import find_packages, setup

this_directory = os.path.abspath(os.path.dirname(__file__))

today = datetime.date.today().timetuple()
year = today[0] % 1000
month = today[1]
day = today[2]

release_package = find_packages(
where=".",
include=[
"*",
],
exclude=["tests", "tests.*"],
)

package_data = {"": ["*.yml", "*.config"], }

release = os.environ.get("NVFL_RELEASE")
version = os.environ.get("NVFL_VERSION")

if release == "1":
package_dir = {"nvflare": "nvflare"}
package_name = "nvflare-light"
else:
package_dir = {"nvflare": "nvflare"}
package_name = "nvflare-light-nightly"

setup(
name=package_name,
version=version,
package_dir=package_dir,
packages=release_package,
package_data=package_data,
include_package_data=True,
)

0 comments on commit 7d133c9

Please sign in to comment.