Skip to content

Commit

Permalink
temp hold
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Jul 21, 2023
1 parent 1817a34 commit b805725
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 101 deletions.
76 changes: 0 additions & 76 deletions nvflare/client/cache.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .module import LightningModule as LightningModule
from .api import to_fl as to_fl
from .api import init as init
36 changes: 12 additions & 24 deletions nvflare/lightning/module.py → nvflare/client/lightning/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@
import nvflare.client as flare


def init():
config_file = "nvf_lightning.json"
config = {
"exchange_path": "./",
"exchange_format": "pytorch",
"params_type": "FULL"
}
with open(config_file, "w") as f:
json.dump(config, f)
flare.init(config=config_file)


def unflatten(global_weights):
"Unflattens the params from NVFlare."
result = {}
Expand Down Expand Up @@ -69,17 +81,6 @@ def on_train_end(self):
super().on_train_end()
print("\n *****nvflare****** on_train_end ********** \n")
self._fl_train_end()

def _fl_init(self):
config_file = "nvf_lightning.json"
config = {
"exchange_path": "./",
"exchange_format": "pytorch",
"params_type": "FULL"
}
with open(config_file, "w") as f:
json.dump(config, f)
flare.init(config=config_file)

def _fl_train_start(self):
print("ZZZZZ calling _fl_train_start ZZZZZ")
Expand All @@ -99,19 +100,6 @@ def _fl_train_end(self):
flare.submit_model(weights)
print("ZZZZZ ending _fl_train_end ZZZZZ")

@staticmethod
def fit_start(_func):
""" Decorator factory. """

def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
self._fl_init()
return func(self, *args, **kwargs)

return wrapper
return decorator(_func)

@staticmethod
def train_start(_func):
""" Decorator factory. """
Expand Down

0 comments on commit b805725

Please sign in to comment.