Skip to content

Commit

Permalink
fixing classes_per_node according to data_name
Browse files Browse the repository at this point in the history
  • Loading branch information
AvivSham committed Feb 11, 2021
1 parent 278aea0 commit 07b53d6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
12 changes: 8 additions & 4 deletions experiments/pfedhn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def evaluate(nodes: BaseNodes, num_nodes, hnet, net, criteria, device, split='te
return results


def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int,
def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
steps: int, inner_steps: int, optim: str, lr: float, inner_lr: float,
embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
Expand All @@ -68,7 +68,7 @@ def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int,
###############################
# init nodes, hnet, local net #
###############################
nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_user,
nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
batch_size=bs)

embed_dim = embed_dim
Expand Down Expand Up @@ -254,7 +254,6 @@ def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int,
"--data-name", type=str, default="cifar10", choices=['cifar10', 'cifar100'], help="dir path for MNIST dataset"
)
parser.add_argument("--data-path", type=str, default="data", help="dir path for MNIST dataset")
parser.add_argument("--classes_per_user", type=int, default=2, help="N classes assigned to each user")
parser.add_argument("--num-nodes", type=int, default=50, help="number of simulated nodes")

##################################
Expand Down Expand Up @@ -296,10 +295,15 @@ def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int,

device = get_device(gpus=args.gpu)

if args.data_name == 'cifar10':
args.classes_per_node = 2
else:
args.classes_per_node = 10

train(
data_name=args.data_name,
data_path=args.data_path,
classes_per_user=args.classes_per_user,
classes_per_node=args.classes_per_node,
num_nodes=args.num_nodes,
steps=args.num_steps,
inner_steps=args.inner_steps,
Expand Down
6 changes: 5 additions & 1 deletion experiments/pfedhn_pc/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
)
parser.add_argument("--data-path", type=str, default='/cortex/data/images', help='data path')
parser.add_argument("--num-nodes", type=int, default=50)
parser.add_argument("--classes-per-node", type=int, default=2)

##################################
# Optimization args #
Expand Down Expand Up @@ -306,6 +305,11 @@ def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,

device = get_device(gpus=args.gpu)

if args.data_name == 'cifar10':
args.classes_per_node = 2
else:
args.classes_per_node = 10

train(
data_name=args.data_name,
data_path=args.data_path,
Expand Down

0 comments on commit 07b53d6

Please sign in to comment.