Skip to content

Commit

Permalink
add test_TinyImageNet
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Jul 31, 2024
1 parent 41b9b07 commit 658d8e5
Showing 1 changed file with 43 additions and 5 deletions.
48 changes: 43 additions & 5 deletions test/test_data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
import torch

from fl_sim.data_processing import ( # noqa: F401; abstract base classes; datasets from FedML; FedCIFAR, # the same as FedCIFAR currently; datasets from FedProx; libsvm datasets
from fl_sim.data_processing import ( # noqa: F401
FedCIFAR100,
FedDataset,
FedEMNIST,
Expand All @@ -25,6 +25,7 @@
FedShakespeare,
FedSynthetic,
FedVisionDataset,
TinyImageNet,
libsvmread,
)
from fl_sim.data_processing.fed_dataset import NLPDataset
Expand All @@ -42,7 +43,6 @@ def test_FedVisionDataset():

@torch.no_grad()
def test_FedCIFAR100():
""" """
ds = FedCIFAR100()
assert ds.n_class == 100
assert len(ds._client_ids_train) == ds.DEFAULT_TRAIN_CLIENTS_NUM
Expand Down Expand Up @@ -358,7 +358,6 @@ def test_FedProxSent140():

@torch.no_grad()
def test_FedProxFEMNIST():
""" """
ds = FedProxFEMNIST()

assert str(ds) == repr(ds)
Expand Down Expand Up @@ -405,7 +404,6 @@ def test_FedProxFEMNIST():

@torch.no_grad()
def test_FedProxMNIST():
""" """
ds = FedProxMNIST()

assert str(ds) == repr(ds)
Expand Down Expand Up @@ -452,7 +450,6 @@ def test_FedProxMNIST():

@torch.no_grad()
def test_FedSynthetic():
""" """
ds = FedSynthetic(1, 1, False, 30)

assert repr(ds) == str(ds)
Expand Down Expand Up @@ -540,3 +537,44 @@ def test_NLPDataset():

ds = NLPDataset.from_huggingface_dataset("sst2", split="train")
assert len(ds) > 0


@torch.no_grad()
def test_TinyImageNet():
ds = TinyImageNet()

assert str(ds) == repr(ds)
assert isinstance(ds.doi, list) and all(isinstance(d, str) for d in ds.doi)

assert ds.n_class == 200

ds.view_image(0, 0)
ds.random_grid_view(3, 3, save_path="test_TinyImageNet.pdf")

train_dl, test_dl = ds.get_dataloader(client_idx=0)
assert len(train_dl) > 0 and len(test_dl) > 0

train_dl, test_dl = ds.get_dataloader(client_idx=None)
assert len(train_dl) > 0 and len(test_dl) > 0

candidate_models = ds.candidate_models
assert len(candidate_models) > 0 and isinstance(candidate_models, dict)
for model_name, model in candidate_models.items():
assert isinstance(model_name, str) and isinstance(model, torch.nn.Module)

model.eval()
batch = next(iter(train_dl))
loss = ds.criterion(model(batch[0]), batch[1])
assert isinstance(loss, torch.Tensor) and loss.dim() == 0

prob = model.predict_proba(batch[0][0], batched=False)
pred = model.predict(batch[0][0], batched=False)
assert isinstance(prob, np.ndarray) and prob.ndim == 1
assert isinstance(pred, int)

prob = model.predict_proba(batch[0], batched=True)
pred = model.predict(batch[0], batched=True)
assert isinstance(prob, np.ndarray) and prob.ndim == 2
assert isinstance(pred, list) and len(pred) == batch[0].shape[0]
eval_res = ds.evaluate(torch.from_numpy(prob), batch[1])
assert isinstance(eval_res, dict) and len(eval_res) > 0

0 comments on commit 658d8e5

Please sign in to comment.