Skip to content

Commit

Permalink
update test case to take into account param requirements and blocksize
Browse files Browse the repository at this point in the history
  • Loading branch information
asahni04 committed Nov 7, 2024
1 parent a4e320c commit 668a4ba
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,11 @@ def test_optim_exclude_low_bit_params(self, optim_name, dtype, device):
self.assertTrue(exp_avg.__class__ == torch.Tensor)
self.assertTrue(exp_avg_sq.__class__ == torch.Tensor)
for param in model.parameters():
if id(param) not in excluded_params_ids :
if id(param) not in excluded_params_ids and param.numel() >= 4096 and param.numel() % optim.block_size == 0:
param_state = state[param]
exp_avg = param_state['exp_avg']
exp_avg_sq = param_state['exp_avg_sq']

self.assertTrue(exp_avg.__class__ != torch.Tensor)
self.assertTrue(exp_avg_sq.__class__ != torch.Tensor)

Expand Down

0 comments on commit 668a4ba

Please sign in to comment.