-
Notifications
You must be signed in to change notification settings - Fork 881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
docs(framework) Add how-to guide for designing stateful ClientApp
objects
#4010
Conversation
ClietnApp
objectsClientApp
objects
Hi, I am not sure if I should create a new issue or comment here. I discussed with @jafermarq on how to implement a stateful client on Slack, and he gave me great tips and a link to this repository to provide my feedback on whether everything is working as expected. I followed the steps provided to store the client parameters in the context, like the following in the # Save the parameters after training to the context for the next round
p_record = ParametersRecord()
for k, v in self.net.state_dict().items():
# Convert to NumPy, then to Array. Add to record
p_record[k] = ndarray_to_array(v.cpu().numpy())
# Add to context
self.client_state.parameters_records["prev_round"] = p_record I loaded the context in # Here, I am hoping to get the previous state dict from the client_state
prev_state_dict = {}
# Extract record from context
p_record = self.client_state.parameters_records["prev_round"]
# Deserialize arrays
for k, v in p_record.items():
prev_state_dict[k] = torch.from_numpy(np.copy(basic_array_deserialization(v)))
prev_state_dict_list = segment_resnet_parameters(flatten_resnet_parameters(prev_state_dict)) Everything is working as expected, but I encountered the following warning:
I found a way to make it work using Thanks. |
Hi @Dawitkiros , sorry for the delay. Could you try this: from flwr.common import array_from_numpy
p_record = ParametersRecord()
for k, v in self.net.state_dict().items():
# Convert to NumPy, then to Array. Add to record
# p_record[k] = ndarray_to_array(v.cpu().numpy())
p_record[k] = array_from_numpy(v.cpu().numpy()) # use utility method
...
for k, v in p_record.items():
# prev_state_dict[k] = torch.from_numpy(np.copy(basic_array_deserialization(v)))
prev_state_dict[k] = torch.from_numpy(v.numpy()) # instead of the line above This shows how to do it with the built-int serialization from
|
Hi @jafermarq, Your suggestion works! |
amazing! thanks for confirming. If you have some suggestions on how to improve this doc page. We are very happy to incorporate your feedback. We might be merging this PR soon. |
Co-authored-by: Chong Shen Ng <chong.shen@flower.ai>
Co-authored-by: Heng Pan <pan@flower.ai>
Co-authored-by: Heng Pan <pan@flower.ai>
Co-authored-by: Heng Pan <pan@flower.ai>
Shows how to make use of the
Context
's state to store metrics and model parameters.Content formatted with https://github.com/dzhu/rstfmt