Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cifar10 exmaple change model to Resnet, Evaluate accuracy is very low. #60

Open
shou123 opened this issue Dec 8, 2023 · 0 comments
Open

Comments

@shou123
Copy link

shou123 commented Dec 8, 2023

❓ Questions and Help

Until we move the questions to another medium, feel free to use this to submit your question:

Question

I run the FLsim cifar10 example, change the model from simple CNN to Resnet, and run the same dataset. The evaluate accuracy is very low. The report as:

Train finished Global Round: 2
(round = 2, epoch = 1, global round = 2), Loss/Training: 1.919056011840796
(round = 2, epoch = 1, global round = 2), Accuracy/Training: 29.19
(round = 2, epoch = 1, global round = 2), Loss/Aggregation: 2.3162186018220936
(round = 2, epoch = 1, global round = 2), Accuracy/Aggregation: 12.362
(round = 2, epoch = 1, global round = 2): Evaluates global model on all data of eval users
(round = 2, epoch = 1, global round = 2), Loss/Eval: 2.315063210050012
(round = 2, epoch = 1, global round = 2), Accuracy/Eval: 12.2
Current eval accuracy: {'Accuracy': 12.2}%, Best so far: {'Accuracy': 10.01}%

IMAGE_SIZE = 32


def build_data_provider(local_batch_size, examples_per_user, drop_last: bool = False):

    #============================================iid===============================================================
    transform = transforms.Compose(
        [
            transforms.Resize(IMAGE_SIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    train_dataset = CIFAR10(
        root="/home/shiyue/FLsim/cifar10", train=True, download=True, transform=transform
    )
    test_dataset = CIFAR10(
        root="/home/shiyue/FLsim/cifar10", train=False, download=True, transform=transform
    )
    sharder = SequentialSharder(examples_per_shard=examples_per_user)
    fl_data_loader = DataLoader(train_dataset, test_dataset, test_dataset, sharder, local_batch_size, drop_last)
    

    data_provider = DataProvider(fl_data_loader)
    return data_provider


def main(
    trainer_config,
    data_config,
    use_cuda_if_available: bool = True,
) -> None:
    cuda_enabled = torch.cuda.is_available() and use_cuda_if_available
    device = torch.device(f"cuda:{0}" if cuda_enabled else "cpu")
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)


    # pyre-fixme[6]: Expected `Optional[str]` for 2nd param but got `device`.
    global_model = FLModel(model, device)
    if cuda_enabled:
        global_model.fl_cuda()
    trainer = instantiate(trainer_config, model=global_model, cuda_enabled=cuda_enabled)
    data_provider = build_data_provider(
        local_batch_size=data_config.local_batch_size,
        examples_per_user=data_config.examples_per_user,
        # examples_per_user = trainer_config.users_per_round,
        drop_last=False,
    )

    metrics_reporter = MetricsReporter([Channel.TENSORBOARD, Channel.STDOUT])

    final_model, eval_score = trainer.train(
        data_provider=data_provider,
        metrics_reporter=metrics_reporter,
        num_total_users=data_provider.num_train_users(),
        distributed_world_size=1,
    )

    trainer.test(
        data_provider=data_provider,
        metrics_reporter=MetricsReporter([Channel.STDOUT]),
    )


@hydra.main(config_path=None, config_name="cifar10_tutorial")
def run(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))

    trainer_config = cfg.trainer
    data_config = cfg.data

    main(
        trainer_config,
        data_config,
    )


if __name__ == "__main__":
    cfg = maybe_parse_json_config()
    run(cfg)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant