From 668a4ba1b767bd11d6a7bf9cd2c5d97ede7bf7f9 Mon Sep 17 00:00:00 2001 From: asahni Date: Wed, 6 Nov 2024 18:14:39 -0800 Subject: [PATCH] update test case to take into account param requirements and blocksize --- test/prototype/test_low_bit_optim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 16202451a6..c87f482dbb 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -390,10 +390,11 @@ def test_optim_exclude_low_bit_params(self, optim_name, dtype, device): self.assertTrue(exp_avg.__class__ == torch.Tensor) self.assertTrue(exp_avg_sq.__class__ == torch.Tensor) for param in model.parameters(): - if id(param) not in excluded_params_ids : + if id(param) not in excluded_params_ids and param.numel() >= 4096 and param.numel() % optim.block_size == 0: param_state = state[param] exp_avg = param_state['exp_avg'] exp_avg_sq = param_state['exp_avg_sq'] + self.assertTrue(exp_avg.__class__ != torch.Tensor) self.assertTrue(exp_avg_sq.__class__ != torch.Tensor)