Skip to content

Commit d6f4bae

Browse files
committed
fixes tests
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent ec72d31 commit d6f4bae

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

tests/test_compose.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,16 @@ def test_data_loader(self):
169169
set_determinism(seed=123)
170170
train_loader = DataLoader(train_ds, num_workers=0)
171171
out_1 = next(iter(train_loader))
172-
self.assertAlmostEqual(out_1.cpu().item(), 0.84291356)
172+
self.assertAlmostEqual(out_1.cpu().item(), 0.0409280)
173173

174174
if sys.platform != "win32": # skip multi-worker tests on win32
175175
train_loader = DataLoader(train_ds, num_workers=1)
176176
out_1 = next(iter(train_loader))
177-
self.assertAlmostEqual(out_1.cpu().item(), 0.180814653)
177+
self.assertAlmostEqual(out_1.cpu().item(), 0.78663897075)
178178

179179
train_loader = DataLoader(train_ds, num_workers=2)
180180
out_1 = next(iter(train_loader))
181-
self.assertAlmostEqual(out_1.cpu().item(), 0.04293707)
181+
self.assertAlmostEqual(out_1.cpu().item(), 0.785907334)
182182
set_determinism(None)
183183

184184
def test_data_loader_2(self):
@@ -191,16 +191,16 @@ def test_data_loader_2(self):
191191

192192
train_loader = DataLoader(train_ds, num_workers=0)
193193
out_2 = next(iter(train_loader))
194-
self.assertAlmostEqual(out_2.cpu().item(), 0.7858843729)
194+
self.assertAlmostEqual(out_2.cpu().item(), 0.98921915918)
195195

196196
if sys.platform != "win32": # skip multi-worker tests on win32
197197
train_loader = DataLoader(train_ds, num_workers=1)
198198
out_2 = next(iter(train_loader))
199-
self.assertAlmostEqual(out_2.cpu().item(), 0.305763411)
199+
self.assertAlmostEqual(out_2.cpu().item(), 0.32985207)
200200

201201
train_loader = DataLoader(train_ds, num_workers=2)
202202
out_1 = next(iter(train_loader))
203-
self.assertAlmostEqual(out_1.cpu().item(), 0.131966779)
203+
self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)
204204
set_determinism(None)
205205

206206
def test_flatten_and_len(self):

tests/test_grid_dataset.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_loading_array(self):
6060
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
6161
np.testing.assert_allclose(
6262
item[0],
63-
np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]),
63+
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
6464
rtol=1e-4,
6565
)
6666
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -69,7 +69,9 @@ def test_loading_array(self):
6969
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
7070
np.testing.assert_allclose(
7171
item[0],
72-
np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]),
72+
np.array(
73+
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
74+
),
7375
rtol=1e-3,
7476
)
7577
np.testing.assert_allclose(
@@ -102,7 +104,7 @@ def test_loading_dict(self):
102104
self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
103105
np.testing.assert_allclose(
104106
item[0]["image"],
105-
np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]),
107+
np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
106108
rtol=1e-4,
107109
)
108110
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -111,7 +113,9 @@ def test_loading_dict(self):
111113
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
112114
np.testing.assert_allclose(
113115
item[0]["image"],
114-
np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]),
116+
np.array(
117+
[[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
118+
),
115119
rtol=1e-3,
116120
)
117121
np.testing.assert_allclose(

tests/test_patch_dataset.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_loading_array(self):
5959
np.testing.assert_allclose(
6060
item[0],
6161
np.array(
62-
[[[1.338681, 2.338681, 3.338681], [5.338681, 6.338681, 7.338681], [9.338681, 10.338681, 11.338681]]]
62+
[[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]]
6363
),
6464
rtol=1e-5,
6565
)
@@ -69,13 +69,7 @@ def test_loading_array(self):
6969
np.testing.assert_allclose(
7070
item[0],
7171
np.array(
72-
[
73-
[
74-
[4.957847, 5.957847, 6.957847],
75-
[8.957847, 9.957847, 10.957847],
76-
[12.957847, 13.957847, 14.957847],
77-
]
78-
]
72+
[[[0.234308, 1.234308, 2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]]
7973
),
8074
rtol=1e-5,
8175
)

0 commit comments

Comments
 (0)