
Fix for weights-only load (pytorch#1228)
stack-info: PR: pytorch#1228, branch: drisspg/stack/19
drisspg authored and asahni04 committed Dec 5, 2024
1 parent 9372454 commit ed83e26
Showing 1 changed file with 11 additions and 4 deletions.
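
For context on the commit title: "weights-only load" refers to torch.load's restricted deserialization mode, which admits tensors, primitives, and standard containers but rejects arbitrary pickled objects unless they are allow-listed. A minimal sketch of that mode (the checkpoint path below is a placeholder, not taken from this commit):

import torch

# Save a plain state dict, then reload it in weights-only mode.
state = {"weight": torch.randn(4, 32)}
torch.save(state, "checkpoint.pt")

# weights_only=True restricts unpickling to tensors, primitives, and standard
# containers; custom classes (such as quantized optimizer state subclasses)
# must be allow-listed, for instance via torch.serialization.add_safe_globals
# on recent PyTorch versions.
loaded = torch.load("checkpoint.pt", weights_only=True)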
15 changes: 11 additions & 4 deletions test/prototype/test_low_bit_optim.py
@@ -32,8 +32,11 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
 )
-
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_6,
+)
 
 try:
     import bitsandbytes as bnb
@@ -199,7 +202,9 @@ def test_optim_8bit_correctness(self, optim_name):
         block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
 
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
+        optim2 = getattr(low_bit_optim, optim_name)(
+            model2.parameters(), block_size=block_size
+        )
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
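
The change above only reflows the construction of the torchao optimizer under test; the arguments are unchanged. As a rough usage sketch of that API, assuming the parametrized optim_name resolves to an 8-bit AdamW variant (AdamW8bit here is illustrative, and the device handling is simplified relative to the test):

import torch
from torchao.prototype import low_bit_optim

device = "cuda"  # the real test parametrizes the device
model = torch.nn.Linear(32, 32, device=device)
# block_size controls the quantization block size, matching the
# block_size argument threaded through the test above.
optim = low_bit_optim.AdamW8bit(model.parameters(), block_size=256)

x = torch.randn(4, 32, device=device)
model(x).sum().backward()
optim.step()
optim.zero_grad()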
@@ -315,7 +320,9 @@ def test_optim_cpu_offload_save_load(self):
 
         # resume training
         model2 = copy.deepcopy(model1)
-        optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
+        optim2 = low_bit_optim.CPUOffloadOptimizer(
+            model2.parameters(), torch.optim.AdamW
+        )
         optim2.load_state_dict(state_dict)
 
         for _ in range(2):
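
This last hunk also only reflows a call, but it sits in the save/load test most directly related to the commit title: optimizer state is checkpointed and then restored into a freshly constructed CPUOffloadOptimizer. A rough sketch of that round trip under a weights-only load (the file name and the tiny CPU-only model are placeholders, and the actual fix made by this commit is not reproduced here):

import copy
import torch
from torchao.prototype import low_bit_optim

model1 = torch.nn.Linear(32, 32)
optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW)
# ... a few training steps would populate the optimizer state here ...
torch.save(optim1.state_dict(), "optim.pt")

# Resume training: rebuild the model/optimizer pair, then restore state with a
# weights-only load, which only admits tensors and plain containers.
model2 = copy.deepcopy(model1)
optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW)
state_dict = torch.load("optim.pt", weights_only=True)
optim2.load_state_dict(state_dict)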
