diff --git a/test/test_data_processing.py b/test/test_data_processing.py
index 98e4872..c3353bd 100644
--- a/test/test_data_processing.py
+++ b/test/test_data_processing.py
@@ -10,7 +10,7 @@
 import pytest
 import torch
 
-from fl_sim.data_processing import ( # noqa: F401; abstract base classes; datasets from FedML; FedCIFAR, # the same as FedCIFAR currently; datasets from FedProx; libsvm datasets
+from fl_sim.data_processing import ( # noqa: F401
     FedCIFAR100,
     FedDataset,
     FedEMNIST,
@@ -25,6 +25,7 @@
     FedShakespeare,
     FedSynthetic,
     FedVisionDataset,
+    TinyImageNet,
     libsvmread,
 )
 from fl_sim.data_processing.fed_dataset import NLPDataset
@@ -42,7 +43,6 @@ def test_FedVisionDataset():
 
 @torch.no_grad()
 def test_FedCIFAR100():
-    """ """
     ds = FedCIFAR100()
     assert ds.n_class == 100
     assert len(ds._client_ids_train) == ds.DEFAULT_TRAIN_CLIENTS_NUM
@@ -358,7 +358,6 @@ def test_FedProxSent140():
 
 @torch.no_grad()
 def test_FedProxFEMNIST():
-    """ """
     ds = FedProxFEMNIST()
 
     assert str(ds) == repr(ds)
@@ -405,7 +404,6 @@ def test_FedProxFEMNIST():
 
 @torch.no_grad()
 def test_FedProxMNIST():
-    """ """
     ds = FedProxMNIST()
 
     assert str(ds) == repr(ds)
@@ -452,7 +450,6 @@ def test_FedProxMNIST():
 
 @torch.no_grad()
 def test_FedSynthetic():
-    """ """
     ds = FedSynthetic(1, 1, False, 30)
 
     assert repr(ds) == str(ds)
@@ -540,3 +537,63 @@ def test_NLPDataset():
 
     ds = NLPDataset.from_huggingface_dataset("sst2", split="train")
     assert len(ds) > 0
+
+
+@torch.no_grad()
+def test_TinyImageNet():
+    """Smoke-test the ``TinyImageNet`` dataset.
+
+    Checks ``str``/``repr`` agreement, the ``doi`` list, the class count,
+    ``view_image`` and ``random_grid_view``, the dataloaders for
+    ``client_idx=0`` and ``client_idx=None``, and, for every candidate model,
+    the loss, the single-sample and batched predictions, and ``evaluate``.
+    """
+    import tempfile
+    from pathlib import Path
+
+    ds = TinyImageNet()
+
+    assert str(ds) == repr(ds)
+    assert isinstance(ds.doi, list) and all(isinstance(d, str) for d in ds.doi)
+
+    assert ds.n_class == 200
+
+    ds.view_image(0, 0)
+    # Save into a temporary directory instead of the current working
+    # directory so the test leaves no PDF behind (assumes ``save_path`` is
+    # where the viewer writes the figure).
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        pdf_path = Path(tmp_dir) / "test_TinyImageNet.pdf"
+        ds.random_grid_view(3, 3, save_path=str(pdf_path))
+
+    train_dl, test_dl = ds.get_dataloader(client_idx=0)
+    assert len(train_dl) > 0 and len(test_dl) > 0
+
+    train_dl, test_dl = ds.get_dataloader(client_idx=None)
+    assert len(train_dl) > 0 and len(test_dl) > 0
+
+    candidate_models = ds.candidate_models
+    assert len(candidate_models) > 0 and isinstance(candidate_models, dict)
+
+    # One batch serves every model: fetch it once rather than creating a new
+    # dataloader iterator on each pass through the loop.
+    batch = next(iter(train_dl))
+
+    for model_name, model in candidate_models.items():
+        assert isinstance(model_name, str) and isinstance(model, torch.nn.Module)
+
+        model.eval()
+        loss = ds.criterion(model(batch[0]), batch[1])
+        assert isinstance(loss, torch.Tensor) and loss.dim() == 0
+
+        prob = model.predict_proba(batch[0][0], batched=False)
+        pred = model.predict(batch[0][0], batched=False)
+        assert isinstance(prob, np.ndarray) and prob.ndim == 1
+        assert isinstance(pred, int)
+
+        prob = model.predict_proba(batch[0], batched=True)
+        pred = model.predict(batch[0], batched=True)
+        assert isinstance(prob, np.ndarray) and prob.ndim == 2
+        assert isinstance(pred, list) and len(pred) == batch[0].shape[0]
+        eval_res = ds.evaluate(torch.from_numpy(prob), batch[1])
+        assert isinstance(eval_res, dict) and len(eval_res) > 0