Skip to content

Commit

Permalink
update test for utils functions
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Aug 1, 2024
1 parent def0df2 commit 7f9a3e1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# .coveragerc to control coverage.py
[run]
omit =
# omit all tests
# omit all test files
tests/*
# omit algorithms, which will be tested separately
fl_sim/algorithms/*
Expand All @@ -12,6 +12,7 @@ omit =
fl_sim/data_processing/libsvm_datasets.py
# visualization panel requires notebook environment
fl_sim/utils/viz.py
# fl_sim/utils/torch_compat.py
# compressors not used currently
fl_sim/compressors/*

Expand Down
34 changes: 33 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
make_serializable,
ordered_dict_to_dict,
)
from fl_sim.utils.torch_compat import torch_norm


def test_get_scheduler():
Expand Down Expand Up @@ -226,5 +227,36 @@ def test_experiment_indicator():


def test_url_is_reachable():
assert url_is_reachable("https://www.google.com/")
assert url_is_reachable("https://www.cloudflare.com/")
assert not url_is_reachable("https://www.google.com/xxx")


def test_torch_norm():
a = torch.arange(9, dtype=torch.float) - 4
b = a.reshape((3, 3))
assert torch.allclose(torch_norm(a, dtype=torch.float64), torch.tensor(7.7460, dtype=torch.float64), atol=1e-4)
assert torch.allclose(torch_norm(b, dtype=torch.float64), torch.tensor(7.7460, dtype=torch.float64), atol=1e-4)
assert torch.allclose(torch_norm(a, float("inf"), dtype=torch.float64), torch.tensor(4.0, dtype=torch.float64), atol=1e-4)
assert torch.allclose(torch_norm(b, float("inf"), dtype=torch.float64), torch.tensor(4.0, dtype=torch.float64), atol=1e-4)
assert torch.allclose(torch_norm(b, p="nuc", dtype=torch.float64), torch.tensor(9.7980, dtype=torch.float64), atol=1e-4)

c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float)
assert torch.allclose(
torch_norm(c, dim=0, dtype=torch.float64), torch.tensor([1.4142, 2.2361, 5.0000], dtype=torch.float64), atol=1e-4
)
assert torch.allclose(
torch_norm(c, dim=1, dtype=torch.float64), torch.tensor([3.7417, 4.2426], dtype=torch.float64), atol=1e-4
)
assert torch.allclose(
torch_norm(c, p=1, dim=1, dtype=torch.float64), torch.tensor([6.0, 6.0], dtype=torch.float64), atol=1e-4
)

d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
assert torch.allclose(
torch_norm(d, dim=(1, 2), dtype=torch.float64), torch.tensor([3.7417, 11.2250], dtype=torch.float64), atol=1e-4
)
assert torch.allclose(torch_norm(d[0, :, :], dtype=torch.float64), torch.tensor(3.7417, dtype=torch.float64), atol=1e-4)
assert torch.allclose(torch_norm(d[1, :, :], dtype=torch.float64), torch.tensor(11.2250, dtype=torch.float64), atol=1e-4)
assert torch.allclose(
torch_norm(d, p="nuc", dim=[1, 2], dtype=torch.float64), torch.tensor([4.2426, 11.4018], dtype=torch.float64), atol=1e-4
)

0 comments on commit 7f9a3e1

Please sign in to comment.