diff --git a/experiments/pfedhn/trainer.py b/experiments/pfedhn/trainer.py index 1197dc4..1964713 100644 --- a/experiments/pfedhn/trainer.py +++ b/experiments/pfedhn/trainer.py @@ -59,7 +59,7 @@ def evaluate(nodes: BaseNodes, num_nodes, hnet, net, criteria, device, split='te return results -def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int, +def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int, steps: int, inner_steps: int, optim: str, lr: float, inner_lr: float, embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int, n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path, @@ -68,7 +68,7 @@ def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int, ############################### # init nodes, hnet, local net # ############################### - nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_user, + nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node, batch_size=bs) embed_dim = embed_dim @@ -254,7 +254,6 @@ def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int, "--data-name", type=str, default="cifar10", choices=['cifar10', 'cifar100'], help="dir path for MNIST dataset" ) parser.add_argument("--data-path", type=str, default="data", help="dir path for MNIST dataset") - parser.add_argument("--classes_per_user", type=int, default=2, help="N classes assigned to each user") parser.add_argument("--num-nodes", type=int, default=50, help="number of simulated nodes") ################################## @@ -296,10 +295,15 @@ def train(data_name: str, data_path: str, classes_per_user: int, num_nodes: int, device = get_device(gpus=args.gpu) + if args.data_name == 'cifar10': + args.classes_per_node = 2 + else: + args.classes_per_node = 10 + train( data_name=args.data_name, data_path=args.data_path, - classes_per_user=args.classes_per_user, + classes_per_node=args.classes_per_node, num_nodes=args.num_nodes, steps=args.num_steps, inner_steps=args.inner_steps, diff --git a/experiments/pfedhn_pc/trainer.py b/experiments/pfedhn_pc/trainer.py index 7eaf68a..1ef6b28 100644 --- a/experiments/pfedhn_pc/trainer.py +++ b/experiments/pfedhn_pc/trainer.py @@ -266,7 +266,6 @@ def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int, ) parser.add_argument("--data-path", type=str, default='/cortex/data/images', help='data path') parser.add_argument("--num-nodes", type=int, default=50) - parser.add_argument("--classes-per-node", type=int, default=2) ################################## # Optimization args # @@ -306,6 +305,11 @@ def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int, device = get_device(gpus=args.gpu) + if args.data_name == 'cifar10': + args.classes_per_node = 2 + else: + args.classes_per_node = 10 + train( data_name=args.data_name, data_path=args.data_path,