
Fed prox adaptive mu #41

Closed · wants to merge 4 commits from the FedProx-Adaptive-Mu branch

Conversation

sanaAyrml (Collaborator):

PR Type

Feature

Short Description

I added an adaptive mu feature to the FedProx experiments.

@sanaAyrml sanaAyrml closed this Jul 17, 2023
@sanaAyrml sanaAyrml reopened this Jul 17, 2023
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)

# Set the Proximal Loss weight mu

Collaborator:

This comment is a little stale with the changes below. I would suggest updating it.

sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)

# Set the Proximal Loss weight mu
self.adaptive_proximal_weight = True

Collaborator:

I would suggest that we make this a part of the configuration and an argument to the constructor of this client that is set to false by default (overridden by the config if included) rather than hard-coding it here.
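
A minimal sketch of that shape, assuming a dict-like config; the setup_client method and config key here are illustrative, not the repository's actual API:

```python
# Hypothetical sketch: adaptive proximal weight as a constructor argument,
# defaulting to False, with an optional config override. Names illustrative.
class FedProxClient:
    def __init__(self, adaptive_proximal_weight: bool = False) -> None:
        self.adaptive_proximal_weight = adaptive_proximal_weight

    def setup_client(self, config: dict) -> None:
        # If the config includes the key, it overrides the constructor default.
        if "adaptive_proximal_weight" in config:
            self.adaptive_proximal_weight = bool(config["adaptive_proximal_weight"])
```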

sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)

# Set the Proximal Loss weight mu
self.adaptive_proximal_weight = True
if self.adaptive_proximal_weight is True:

Collaborator:

This can be contracted to

if self.adaptive_proximal_weight:

so long as adaptive_proximal_weight is typed as a bool.

if self.adaptive_proximal_weight is True:
self.proximal_weight_patience = 5
self.proximal_weight_change_value = 0.1
# If sampler is generating non iid data, then set the proximal weight to 0

Collaborator:

Out of curiosity, why are we setting the proximal weight to 0.0 as an initial value if non-iid data is generated?

Collaborator (Author):

It was written in the paper:
" One of the key parameters of FedProx is μ. We provide the complete results of a simple heuristic of adaptively setting μ on four synthetic datasets in Figure 11. For the IID dataset (Synthetic-IID), μ starts from 1, and for the other non-IID datasets, μ starts from 0. Such initialization is adversarial to our methods. We decrease μ by 0.1 when the loss continues to decrease for 5 rounds and increase μ by 0.1 when we see the loss increase. This heuristic allows for competitive performance. It could also alleviate the potential issue that μ > 0 might slow down convergence on IID data, which rarely occurs in real federated settings."

Collaborator:

Gotcha. Do you mind adding a comment like "Following the setup in Appendix C3.3 of the FedProx paper" just to point to the rationale?

Collaborator (Author):

Sure

if self.adaptive_proximal_weight:
if loss <= previous_loss:
self.proximal_weight_patience_counter += 1
if self.proximal_weight_patience_counter == self.proximal_weight_patience:

Collaborator:

For safety I would have

if self.proximal_weight_patience_counter >= self.proximal_weight_patience:

and maybe log a warning if the counter is strictly larger than self.proximal_weight_patience. That way you won't have silent issues if the counter somehow ends up above the patience limit.
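
Concretely, the defensive variant might look like this (an excerpt of the update method; a module-level `log` from the standard logging library is assumed):

```python
if self.proximal_weight_patience_counter >= self.proximal_weight_patience:
    if self.proximal_weight_patience_counter > self.proximal_weight_patience:
        # Should be unreachable if the counter is reset correctly; warn
        # loudly instead of adjusting the weight silently from a bad state.
        log.warning(
            "Patience counter (%d) exceeded the patience limit (%d).",
            self.proximal_weight_patience_counter,
            self.proximal_weight_patience,
        )
    self.proximal_weight -= self.proximal_weight_change_value
    self.proximal_weight_patience_counter = 0
```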

self.proximal_weight_patience_counter += 1
if self.proximal_weight_patience_counter == self.proximal_weight_patience:
self.proximal_weight -= self.proximal_weight_change_value
if self.proximal_weight < 0.0:

Collaborator:

Rather than having another if statement here, I would suggest using

self.proximal_weight = max(0.0, self.proximal_weight)

self.proximal_weight: float = 0.1
self.proximal_weight_patience: int = 5
self.proximal_weight_change_value: float = 0.1

Collaborator:

Just a suggestion for a more compact name (if you don't like it, this is fine): maybe self.proximal_weight_delta.

@@ -76,7 +94,12 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
     local_epochs = self.narrow_config_type(config, "local_epochs", int)
     current_server_round = self.narrow_config_type(config, "current_server_round", int)
     # Currently uses training by epoch.
-    metric_values = self.train_by_epochs(current_server_round, local_epochs, meter)
+    metric_values, total_loss = self.train_by_epochs(current_server_round, local_epochs, meter)

Collaborator:

We support train_by_steps in the FedProxClient as well, so I think that training loop needs to be updated with these changes as well.
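
In other words, the step-based loop would need a parallel change, along these lines (all names here are hypothetical; the point is only that it returns the loss alongside the metrics):

```python
from typing import Any, Dict, Tuple

def train_by_steps(self, current_server_round: int, steps: int, meter: Any) -> Tuple[Dict[str, float], float]:
    # Hypothetical mirror of the train_by_epochs change above.
    total_loss = 0.0
    for _ in range(steps):
        inputs, targets = next(self.train_iterator)     # hypothetical batch iterator
        loss, preds = self.train_step(inputs, targets)  # hypothetical single-step helper
        total_loss += loss
        meter.update(preds, targets)
    # Return the average loss so the caller can feed it to the mu update.
    return meter.compute(), total_loss / max(steps, 1)
```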

Collaborator (Author):

I wanted to update it, but I saw that it's used in the tests folder, so I didn't want to break anything. I'll update it, though.

@emersodb (Collaborator):

Based on reading through the paper and the issue discussion here, my interpretation is that the server updates the value of mu used by every client (they all share the same mu, whether it is adapted or not). The loss being tracked is the global loss, rather than each client tracking its local loss and updating its own mu value. The paper is unclear on whether that loss includes the proximal term; however, in their official implementation (here, for example), it looks like they are just using the aggregated cross-entropy loss rather than the combined loss.

It's possible that having the clients learn their own mu would be useful, but the official FedProx implementation uses a globally adapted mu, so I think going that route makes the most sense for this implementation.
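
A rough sketch of what a server-side, globally adapted mu could look like; the class and method names are illustrative, not this repository's API:

```python
# Hypothetical server-side adaptation of a single global mu, driven by the
# aggregated cross-entropy loss and broadcast to clients via the round config.
class AdaptiveMuServer:
    def __init__(self, initial_mu: float, patience: int = 5, delta: float = 0.1) -> None:
        self.mu = initial_mu
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.previous_loss = float("inf")

    def update_mu(self, aggregated_loss: float) -> float:
        if aggregated_loss <= self.previous_loss:
            self.counter += 1
            if self.counter >= self.patience:
                self.mu = max(0.0, self.mu - self.delta)
                self.counter = 0
        else:
            self.mu += self.delta
            self.counter = 0
        self.previous_loss = aggregated_loss
        return self.mu

    def fit_config(self) -> dict:
        # Clients read the current global mu from the per-round config.
        return {"proximal_weight": self.mu}
```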

@emersodb (Collaborator):

Due to the comment above, the function _maybe_update_proximal_weight_param will likely be relocated to the server side. Wherever it ends up, it would be good for it to have a set of tests associated with it to make sure the adjustments are behaving as expected for various loss trajectories.
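
A couple of illustrative pytest cases for such loss trajectories, assuming the hypothetical AdaptiveMuServer sketched above:

```python
def test_mu_decreases_after_patience_rounds() -> None:
    server = AdaptiveMuServer(initial_mu=1.0, patience=5, delta=0.1)
    for loss in [0.9, 0.8, 0.7, 0.6, 0.5]:  # five consecutive decreases
        server.update_mu(loss)
    assert abs(server.mu - 0.9) < 1e-8

def test_mu_increases_when_loss_increases() -> None:
    server = AdaptiveMuServer(initial_mu=0.0, patience=5, delta=0.1)
    server.update_mu(0.5)  # first observation counts as a decrease from inf
    server.update_mu(0.6)  # loss went up: mu should grow by delta
    assert abs(server.mu - 0.1) < 1e-8
```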

@sanaAyrml sanaAyrml closed this Jul 25, 2023
@sanaAyrml sanaAyrml deleted the FedProx-Adaptive-Mu branch July 25, 2023 16:04