
Commit 7e6cfff
more cleaning and API changes
hwang595 committed May 5, 2020
1 parent ce7d6cb commit 7e6cfff
Showing 4 changed files with 12 additions and 8 deletions.
2 changes: 2 additions & 0 deletions language_modeling/language_utils.py
@@ -4,6 +4,8 @@
 """Utils for language models."""
 
 import re
+import numpy as np
+import torch
 
 
 # ------------------------
1 change: 1 addition & 0 deletions language_modeling/run_fedma_with_comm.sh
@@ -0,0 +1 @@
+python lstm_fedma_with_comm.py
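
The new one-line launcher just starts the FedMA-with-communication LSTM experiment. Assuming it is invoked from the language_modeling/ directory, a typical usage would be:

bash run_fedma_with_comm.sh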
14 changes: 8 additions & 6 deletions main.py
@@ -773,8 +773,8 @@ def local_retrain_fedavg(local_datasets, weights, args, device="cpu"):
     train_dl_local = local_datasets[0]
     test_dl_local = local_datasets[1]
 
-    optimizer_fine_tune = optim.Adam(filter(lambda p: p.requires_grad, matched_cnn.parameters()), lr=0.001, weight_decay=0.0001, amsgrad=True)
-
+    #optimizer_fine_tune = optim.Adam(filter(lambda p: p.requires_grad, matched_cnn.parameters()), lr=0.001, weight_decay=0.0001, amsgrad=True)
+    optimizer_fine_tune = optim.SGD(filter(lambda p: p.requires_grad, matched_cnn.parameters()), lr=args.retrain_lr, momentum=0.9, weight_decay=0.0001)
     criterion_fine_tune = nn.CrossEntropyLoss().to(device)
 
     logger.info('n_training: %d' % len(train_dl_local))
@@ -887,8 +887,8 @@ def local_retrain_fedprox(local_datasets, weights, mu, args, device="cpu"):
     train_dl_local = local_datasets[0]
     test_dl_local = local_datasets[1]
 
-    optimizer_fine_tune = optim.Adam(filter(lambda p: p.requires_grad, matched_cnn.parameters()), lr=0.001, weight_decay=0.0001, amsgrad=True)
-
+    #optimizer_fine_tune = optim.Adam(filter(lambda p: p.requires_grad, matched_cnn.parameters()), lr=0.001, weight_decay=0.0001, amsgrad=True)
+    optimizer_fine_tune = optim.SGD(filter(lambda p: p.requires_grad, matched_cnn.parameters()), lr=args.retrain_lr, momentum=0.9, weight_decay=0.0001)
     criterion_fine_tune = nn.CrossEntropyLoss().to(device)
 
     logger.info('n_training: {}'.format(len(train_dl_local)))
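
Both local_retrain_fedavg and local_retrain_fedprox now fine-tune with SGD at the user-configurable retrain learning rate instead of Adam's hard-coded lr=0.001; the old Adam line is kept as a comment. A minimal sketch of the new setup, using the matched_cnn, args, and device names from the surrounding code:

import torch.nn as nn
import torch.optim as optim

# Fine-tune only the parameters that still require gradients; the learning
# rate now comes from the command line (args.retrain_lr) instead of being
# fixed at Adam's 1e-3.
optimizer_fine_tune = optim.SGD(
    filter(lambda p: p.requires_grad, matched_cnn.parameters()),
    lr=args.retrain_lr,
    momentum=0.9,
    weight_decay=0.0001)
criterion_fine_tune = nn.CrossEntropyLoss().to(device)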
@@ -1506,7 +1506,8 @@ def fedma_comm(batch_weights, model_meta_data, layer_type, net_dataidx_map,
                                          train_dl_global,
                                          test_dl_global,
                                          n_classes,
-                                         args)
+                                         device=device,
+                                         args=args)
     batch_weights = [copy.deepcopy(hungarian_weights) for _ in range(args.n_nets)]
     del hungarian_weights
     del retrained_nets
@@ -1607,7 +1608,8 @@ def fedma_comm(batch_weights, model_meta_data, layer_type, net_dataidx_map,
                                          train_dl_global,
                                          test_dl_global,
                                          n_classes,
-                                         args)
+                                         device=device,
+                                         args=args)
 
     _ = compute_model_averaging_accuracy(models,
                                          averaged_weights,
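
Both retraining call sites inside fedma_comm now pass device and args as keywords rather than relying on positional order, so the target device is explicit instead of silently falling back to the callees' device="cpu" default. A sketch of the new call shape; retrain_fn and its leading argument are hypothetical stand-ins, since the diff only shows the trailing arguments:

# retrain_fn is a hypothetical callee name; only the trailing arguments
# below appear in the diff.
retrained_nets = retrain_fn(nets,                  # hypothetical leading argument
                            train_dl_global,
                            test_dl_global,
                            n_classes,
                            device=device,         # explicit, no more "cpu" default
                            args=args)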
3 changes: 1 addition & 2 deletions run.sh
@@ -10,5 +10,4 @@ python main.py --model=moderate-cnn \
 --comm_type=fedma \
 --comm_round=10 \
 --oneshot_matching= \
---retrain= \
---rematching=
+--retrain=
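
With the --rematching flag dropped, the visible tail of run.sh now reads as follows (the flags before line 10 of the script are not shown in this diff and are elided here):

python main.py --model=moderate-cnn \
    ... \
    --comm_type=fedma \
    --comm_round=10 \
    --oneshot_matching= \
    --retrain=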
