Fed prox adaptive mu #41
Conversation
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)

# Set the Proximal Loss weight mu
This comment is a little stale with the changes below. I would suggest updating it.
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)

# Set the Proximal Loss weight mu
self.adaptive_proximal_weight = True
I would suggest that we make this a part of the configuration and an argument to the constructor of this client that is set to false by default (overridden by the config if included) rather than hard-coding it here.
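The suggestion above could look something like the following sketch. All names here (the class name, the `process_config` hook, the config key) are illustrative assumptions, not the repository's actual API; the point is only the pattern: a constructor argument defaulting to False, overridden by the config when the key is present.

```python
class FedProxClient:
    # Hypothetical sketch of the reviewer's suggestion; names are illustrative.
    def __init__(self, adaptive_proximal_weight: bool = False) -> None:
        # Off by default unless the caller or the config turns it on.
        self.adaptive_proximal_weight = adaptive_proximal_weight

    def process_config(self, config: dict) -> None:
        # The config value, if included, overrides the constructor default.
        if "adaptive_proximal_weight" in config:
            self.adaptive_proximal_weight = bool(config["adaptive_proximal_weight"])
```

This keeps the behavior discoverable from the constructor signature while still letting server-side configuration control it per experiment.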
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1)

# Set the Proximal Loss weight mu
self.adaptive_proximal_weight = True
if self.adaptive_proximal_weight is True:
This can be contracted to
if self.adaptive_proximal_weight:
as long as adaptive_proximal_weight is typed as a bool.
if self.adaptive_proximal_weight is True:
    self.proximal_weight_patience = 5
    self.proximal_weight_change_value = 0.1
    # If sampler is generating non iid data, then set the proximal weight to 0
Out of curiosity, why are we setting the proximal weight to 0.0 as an initial value if non-iid data is generated?
It was written in the paper:
" One of the key parameters of FedProx is μ. We provide the complete results of a simple heuristic of adaptively setting μ on four synthetic datasets in Figure 11. For the IID dataset (Synthetic-IID), μ starts from 1, and for the other non-IID datasets, μ starts from 0. Such initialization is adversarial to our methods. We decrease μ by 0.1 when the loss continues to decrease for 5 rounds and increase μ by 0.1 when we see the loss increase. This heuristic allows for competitive performance. It could also alleviate the potential issue that μ > 0 might slow down convergence on IID data, which rarely occurs in real federated settings."
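The heuristic quoted above can be sketched as a standalone function. This is not the PR's actual code; the function name and signature are hypothetical, and it simply restates the paper's rule: decrease mu by 0.1 after the loss decreases for 5 consecutive rounds, increase mu by 0.1 as soon as the loss increases, and never let mu go below zero.

```python
def update_proximal_weight(mu, loss, previous_loss, patience_counter,
                           patience=5, delta=0.1):
    """Sketch of the adaptive-mu heuristic from the FedProx paper
    (illustrative names, not the repository's API)."""
    if loss <= previous_loss:
        patience_counter += 1
        if patience_counter >= patience:
            # Loss decreased for `patience` rounds in a row: relax mu.
            mu = max(0.0, mu - delta)  # clamp so mu never goes negative
            patience_counter = 0
    else:
        # Loss increased: tighten the proximal term immediately.
        mu += delta
        patience_counter = 0
    return mu, patience_counter
```

Per the paper, mu is initialized to 1 for IID data and 0 for non-IID data, which the authors describe as an adversarial initialization for the method.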
Gotcha. Do you mind adding a comment like "Following the setup in Appendix C3.3 of the FedProx paper" just to point to the rationale?
Sure
if self.adaptive_proximal_weight:
    if loss <= previous_loss:
        self.proximal_weight_patience_counter += 1
        if self.proximal_weight_patience_counter == self.proximal_weight_patience:
For safety I would have
if self.proximal_weight_patience_counter >= self.proximal_weight_patience:
and maybe a warning if the counter is strictly larger than self.proximal_weight_patience. That way you won't have silent issues if the counter somehow ends up above the patience limit.
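A small sketch of that safety pattern (the helper name is hypothetical): trigger on greater-than-or-equal rather than strict equality, and log a warning when the counter has somehow overshot the limit.

```python
import logging

log = logging.getLogger(__name__)

def check_patience(counter: int, patience: int) -> bool:
    """Illustrative sketch of the >= check with an overshoot warning."""
    if counter > patience:
        # This should be unreachable if the counter is reset correctly;
        # warn loudly instead of failing silently.
        log.warning(
            "Patience counter (%d) exceeded the patience limit (%d).",
            counter, patience,
        )
    return counter >= patience
```

With a strict `==` check, a counter that skips past the limit would never trigger the mu update again; the `>=` form makes the condition robust to that failure mode.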
self.proximal_weight_patience_counter += 1
if self.proximal_weight_patience_counter == self.proximal_weight_patience:
    self.proximal_weight -= self.proximal_weight_change_value
    if self.proximal_weight < 0.0:
Rather than having another if statement here, I would suggest using
self.proximal_weight = max(0.0, self.proximal_weight)
self.proximal_weight: float = 0.1
self.proximal_weight_patience: int = 5
self.proximal_weight_change_value: float = 0.1
Just a suggestion for a more compact name, if you don't like it this is fine, but maybe self.proximal_weight_delta
@@ -76,7 +94,12 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
local_epochs = self.narrow_config_type(config, "local_epochs", int)
current_server_round = self.narrow_config_type(config, "current_server_round", int)
# Currently uses training by epoch.
- metric_values = self.train_by_epochs(current_server_round, local_epochs, meter)
+ metric_values, total_loss = self.train_by_epochs(current_server_round, local_epochs, meter)
We support train_by_steps in the FedProxClient as well, so I think that training loop needs to be updated with these changes too.
I wanted to update it, but I saw it is used in the tests folder, so I didn't want to break anything. I'll update it.
Based on reading through the paper and the issue discussion here, my interpretation is that the server updates the value of mu used by each client (they all share the same mu, whether it is adapted or not). The loss being tracked is the global loss, rather than the local loss on each client with clients updating their own mu values. It is unclear in the paper whether the loss includes the proximal loss term. However, in their official implementation (see here, for example), it looks like they are just using the aggregated cross-entropy loss rather than the combined loss. It's possible that having the clients learn their own mu would be useful, but the official FedProx implementation uses a globally adapted mu, so I think going that route makes the most sense for this implementation.
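The globally adapted mu described above could be sketched as server-side state (all names here are hypothetical, not this repository's classes): the server compares the aggregated, non-proximal loss across rounds and adjusts a single mu that is then sent to every client.

```python
from typing import Optional

class AdaptiveMuServer:
    """Illustrative sketch of a server that adapts one global mu."""

    def __init__(self, initial_mu: float = 0.0, patience: int = 5,
                 delta: float = 0.1) -> None:
        self.mu = initial_mu
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.previous_loss: Optional[float] = None

    def update(self, aggregated_loss: float) -> float:
        # Compare the aggregated cross-entropy loss (without the proximal
        # term) against the previous round's value.
        if self.previous_loss is not None:
            if aggregated_loss <= self.previous_loss:
                self.counter += 1
                if self.counter >= self.patience:
                    self.mu = max(0.0, self.mu - self.delta)
                    self.counter = 0
            else:
                self.mu += self.delta
                self.counter = 0
        self.previous_loss = aggregated_loss
        # The returned mu would be broadcast to all clients via the
        # next round's config.
        return self.mu
```

Keeping the adaptation on the server also avoids clients drifting to inconsistent mu values, which matches the single-mu behavior of the official implementation.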
Due to the comment above, the function
PR Type
Feature
Short Description
I added an adaptive mu feature to the FedProx experiments.