diff --git a/test/espnet2/train/test_reporter.py b/test/espnet2/train/test_reporter.py index c928c52523a..9cd796d665c 100644 --- a/test/espnet2/train/test_reporter.py +++ b/test/espnet2/train/test_reporter.py @@ -53,7 +53,7 @@ def test_register(weight1, weight2): desired[k] /= weight1 + weight2 for k1, k2 in reporter.get_all_keys(): - if k2 in ("time", "total_count"): + if k2 in ("time", "total_count", "gpu_max_cached_mem_GB", "gpu_cached_mem_GB"): continue np.testing.assert_allclose(reporter.get_value(k1, k2), desired[k2])