Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dtype bug when offload_state_dict=True and dtype is specified #2116

Merged
merged 7 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,9 @@ def set_module_tensor_to_device(
if value is None:
new_value = old_value.to(device)
if dtype is not None and device in ["meta", torch.device("meta")]:
new_value = new_value.to(dtype)
if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
new_value = new_value.to(dtype)

if not is_buffer:
module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
elif isinstance(value, torch.Tensor):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(input, self.weight, self.bias)


class ModelSeveralDtypes(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("int_param", torch.randint(high=10, size=(15, 30)))
self.register_parameter("float_param", torch.nn.Parameter(torch.rand(10, 5)))

def forward(self, x):
return x + 2


def sequential_model(num_layers):
layers = OrderedDict([(f"linear{i}", nn.Linear(1000, 1000)) for i in range(1, num_layers + 1)])
return nn.Sequential(layers)
Expand Down Expand Up @@ -425,6 +435,19 @@ def test_load_checkpoint_in_model_two_gpu(self):
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
self.assertEqual(model.linear2.weight.device, torch.device(1))

def test_load_checkpoint_in_model_dtype(self):
with tempfile.NamedTemporaryFile(suffix=".pt") as tmpfile:
model = ModelSeveralDtypes()
torch.save(model.state_dict(), tmpfile.name)

new_model = ModelSeveralDtypes()
load_checkpoint_in_model(
new_model, tmpfile.name, offload_state_dict=True, dtype=torch.float16, device_map={"": "cpu"}
)

self.assertEqual(new_model.int_param.dtype, torch.int64)
self.assertEqual(new_model.float_param.dtype, torch.float16)

def test_clean_device_map(self):
# Regroup everything if all is on the same device
self.assertDictEqual(clean_device_map({"a": 0, "b": 0, "c": 0}), {"": 0})
Expand Down