[BUG] Job config of FedJobConfig is not generated correctly #2935

Closed
KCC13 opened this issue Sep 11, 2024 · 3 comments · Fixed by #2955 or #2954


@KCC13

KCC13 commented Sep 11, 2024

Describe the bug
Hi, as the title says, the bug occurs when I try to replace the model (SimpleNetwork) in the hello-pt example with resnet18 from torchvision. The error message shows that the config is not organized correctly and therefore cannot be serialized.

To Reproduce

  1. Go to hello-pt_cifar10_fl.py, import resnet18 by adding from torchvision.models import resnet18, and replace model = SimpleNetwork() with model = resnet18(num_classes=10).
  2. Similarly, in fedavg_script_runner_pt.py, import resnet18 and replace initial_model=SimpleNetwork() with initial_model=resnet18(num_classes=10) in FedAvgJob.
  3. Run fedavg_script_runner_pt.py.
  4. See error

Expected behavior
The simulation should be executed correctly.

Screenshots
Image

Desktop (please complete the following information):

  • OS: MacOS 11.7.10
  • Python Version: 3.9.19
  • NVFlare Version: 2.5.0rc12

Additional context
Following the hint in the error message and inspecting the content of server_app in _get_server_app shows:

{'format_version': 2, 'workflows': [{'id': 'controller', 'path': 'nvflare.app_common.workflows.fedavg.FedAvg', 'args': {'num_clients': 2, 'num_rounds': 2}}], 'components': [{'id': 'json_generator', 'path': 'nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator', 'args': {}}, {'id': 'model_selector', 'path': 'nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector', 'args': {'aggregation_weights': {}, 'key_metric': 'accuracy'}}, {'id': 'receiver', 'path': 'nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver', 'args': {'events': ['fed.analytix_log_stats']}}, {'id': 'persistor', 'path': 'nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor', 'args': {'model': {'path': 'torchvision.models.resnet.ResNet', 'args': {'norm_layer': <class 'torch.nn.modules.batchnorm.BatchNorm2d'>}}}}, {'id': 'locator', 'path': 'nvflare.app_opt.pt.file_model_locator.PTFileModelLocator', 'args': {'pt_persistor_id': 'persistor'}}], 'task_data_filters': [], 'task_result_filters': []}

Several observations:

  • The error is caused by {'norm_layer': <class 'torch.nn.modules.batchnorm.BatchNorm2d'>}.
  • The norm_layer argument of resnet18 takes a class (a subclass of torch.nn.Module) as input, not an instance. However, it seems the current FedJobConfig API cannot handle this kind of input.
  • The num_classes argument we pass to resnet18 (10) differs from the default value (1000), but it is not listed in server_app.
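
The serialization failure in the first observation can be reproduced without NVFlare or PyTorch. This is a minimal sketch, assuming the job config is written out with Python's json module; the stand-in class FakeBatchNorm2d and the trimmed persistor dict are hypothetical and only mirror the shape of the config above.

```python
import json


class FakeBatchNorm2d:
    """Hypothetical stand-in for torch.nn.BatchNorm2d (avoids a torch import)."""


# Trimmed-down persistor entry: a class object sits where JSON expects plain data.
persistor = {
    "id": "persistor",
    "args": {
        "model": {
            "path": "torchvision.models.resnet.ResNet",
            "args": {"norm_layer": FakeBatchNorm2d},
        }
    },
}


def try_serialize(config):
    """Return the JSON text, or the error message if serialization fails."""
    try:
        return json.dumps(config)
    except TypeError as err:
        return f"TypeError: {err}"


result = try_serialize(persistor)
print(result)

# Replacing the class with plain data (here a dotted-path string) serializes fine.
persistor["args"]["model"]["args"]["norm_layer"] = "torch.nn.BatchNorm2d"
ok_result = try_serialize(persistor)
```

Any value in the config that is a class object rather than a string, number, list, or dict will hit the same TypeError.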
@KCC13 KCC13 added the bug Something isn't working label Sep 11, 2024
@yhwen
Collaborator

yhwen commented Sep 18, 2024

@KCC13 Thank you very much for using the Job API, extending its usage, and raising this issue with us. After examining your use case, here are the causes of the errors you experienced:

  • The Job API only supports object instances created directly from a Python class. It does not support object instances created through a Python function. In your case, "resnet18(num_classes=10)" creates a "ResNet" object instance via the "resnet18" function. The Job API can only use that "ResNet" instance to generate the job config; it cannot detect the object-creation logic inside "resnet18".

  • The Job API only supports object instances as constructor parameters. It does not support passing a Python class directly as a parameter. In your use case, "ResNet" takes the "BatchNorm2d" class as a parameter, which is not supported right now. (We may add this support in the future.)

  • The "num_classes" argument is not kept because it is not set as an instance variable in the "ResNet" class constructor, as in the following line. This is described in our Job API documentation.

    self.num_classes = num_classes

Note:
In order for the FedJob to use the values of arguments passed into the obj, the arguments must be set as instance variables of the same name (or prefixed with "_") in the constructor.
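
To see why num_classes was dropped while norm_layer was kept, here is a minimal sketch of attribute-based argument recovery. It is not NVFlare's actual implementation: the helper recover_init_args and the toy classes are hypothetical, and the rule it follows (match constructor parameter names against instance attributes of the same name or prefixed with "_", then keep non-default values) is an assumption drawn from the note above.

```python
import inspect


def recover_init_args(obj):
    """Guess constructor arguments from instance attributes (hypothetical helper)."""
    recovered = {}
    params = inspect.signature(type(obj).__init__).parameters
    for name, param in params.items():
        if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        # Look for an attribute called <name> or _<name>.
        for attr in (name, "_" + name):
            if attr in vars(obj):
                value = vars(obj)[attr]
                if value != param.default:  # keep only non-default values
                    recovered[name] = value
                break
    return recovered


class DefaultNorm:
    """Toy stand-in for a normalization layer class."""


class Net:
    """Toy model that, like ResNet, never stores num_classes on the instance."""

    def __init__(self, num_classes=1000, norm_layer=None):
        self._norm_layer = norm_layer or DefaultNorm  # a class, kept as _norm_layer
        self.out_features = num_classes  # stored under a different name


def make_net(num_classes):
    """Factory function, analogous to resnet18()."""
    return Net(num_classes=num_classes)


class WrappedNet(Net):
    """Wrapper class that records num_classes under its own name."""

    def __init__(self, num_classes=1000):
        self.num_classes = num_classes
        super().__init__(num_classes=num_classes)


plain = make_net(10)
wrapped = WrappedNet(10)
print(type(plain).__name__, recover_init_args(plain))
print(type(wrapped).__name__, recover_init_args(wrapped))
```

Under these assumptions, the object returned by the factory is seen only as a Net: its num_classes is lost and a class object is recovered for norm_layer, which matches the config shown in the issue. The wrapper class exposes num_classes and hides norm_layer from its own constructor signature.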

@KCC13
Author

KCC13 commented Sep 20, 2024

Thank you very much for your clear reply. 🤜🤛

@KCC13 KCC13 closed this as completed Sep 20, 2024
@yhwen
Collaborator

yhwen commented Sep 20, 2024

@KCC13 Here's an example of how to wrap the Python function in a Python class to work around this issue:

  1. Wrap the resnet18 Python function in a Resnet18 class in simple_network.py:
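from typing import Any, Optional

from torchvision.models._utils import _ovewrite_named_param
from torchvision.models.resnet import BasicBlock, ResNet, ResNet18_Weights


class Resnet18(ResNet):

    def __init__(self, num_classes, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any):
        # Store num_classes as an instance variable so the Job API can pick it up.
        self.num_classes = num_classes

        weights = ResNet18_Weights.verify(weights)

        # Route num_classes through kwargs so it reaches ResNet exactly once
        # (passing it both explicitly and via kwargs would raise a TypeError).
        kwargs["num_classes"] = num_classes
        if weights is not None:
            # Raises ValueError if num_classes conflicts with the pretrained weights.
            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

        super().__init__(BasicBlock, [2, 2, 2, 2], **kwargs)

        if weights is not None:
            super().load_state_dict(weights.get_state_dict(progress=progress))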
  2. In hello-pt_cifar10_fl.py, replace model = SimpleNetwork() with model = Resnet18(num_classes=10).
  3. In fedavg_script_runner_pt.py, replace initial_model=SimpleNetwork() with initial_model=Resnet18(num_classes=10) in FedAvgJob.
  4. Run fedavg_script_runner_pt.py. It should then work.
