Skip to content

Commit

Permalink
take torch.nn.Module model into account when moving to device (#3167)
Browse files Browse the repository at this point in the history
* bug fix

* update code
  • Loading branch information
faaany authored Oct 31, 2024
1 parent ffbca15 commit 87732a4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,7 +1998,7 @@ def _prepare_ipex_or_xpu(self, *args):
optimizer = obj
if optimizer is not None and model is not None:
dtype = torch.bfloat16 if self.state.mixed_precision == "bf16" else None
if self.device.type == "xpu" and model.device.type == "cpu":
if self.device.type == "xpu" and next(model.parameters()).device.type == "cpu":
model = model.to(self.device)
# ipex.optimize() is available only for IPEX, both IPEX-CPU and IPEX-XPU
if is_ipex_available():
Expand Down

0 comments on commit 87732a4

Please sign in to comment.