Skip to content

Commit

Permalink
Fix PyTorch MNIST example doc (#1684)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Feb 20, 2023
1 parent 6b54d77 commit 651d05d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions doc/source/example-walkthrough-pytorch-mnist.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,16 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e
self.device = device
self.epochs = epochs
def get_weights(self) -> fl.common.Weights:
def get_weights(self) -> fl.common.NDArrays:
"""Get model weights as a list of NumPy ndarrays."""
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
def set_weights(self, weights: fl.common.Weights) -> None:
def set_weights(self, weights: fl.common.NDArrays) -> None:
"""Set model weights from a list of NumPy ndarrays.
Parameters
----------
weights: fl.common.Weights
weights: fl.common.NDArrays
Weights received by the server and set to local model
Expand All @@ -170,7 +170,7 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e
def get_parameters(self, config) -> fl.common.ParametersRes:
"""Encapsulates the weight into Flower Parameters """
weights: fl.common.Weights = self.get_weights()
weights: fl.common.NDArrays = self.get_weights()
parameters = fl.common.ndarrays_to_parameters(weights)
return fl.common.ParametersRes(parameters=parameters)
Expand All @@ -187,7 +187,7 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e
Set of variables containing the new set of weights and information the client.
"""
weights: fl.common.Weights = fl.common.parameters_to_ndarrays(ins.parameters)
weights: fl.common.NDArrays = fl.common.parameters_to_ndarrays(ins.parameters)
fit_begin = timeit.default_timer()
# Set model parameters/weights
Expand All @@ -199,7 +199,7 @@ Now, let's look closely into the :code:`PytorchMNISTClient` inside :code:`flwr_e
)
# Return the refined weights and the number of examples used for training
weights_prime: fl.common.Weights = self.get_weights()
weights_prime: fl.common.NDArrays = self.get_weights()
params_prime = fl.common.ndarrays_to_parameters(weights_prime)
fit_duration = timeit.default_timer() - fit_begin
return fl.common.FitRes(
Expand Down

0 comments on commit 651d05d

Please sign in to comment.