Commit

replace `.to(<num>)` with `.to("npu:<num>")` when using `torch_npu`
ji-huazhong committed Dec 7, 2023
1 parent 3fb6dcc commit 61a105f
Showing 3 changed files with 11 additions and 2 deletions.
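All three edits below apply the same guard before a tensor or model is moved: if `torch_npu` is available and the requested device is a bare integer index, rewrite it to the string form `"npu:<num>"` first. A minimal sketch of the idea, runnable on its own; the helper name `_int_to_npu_device` is illustrative and not part of the commit, which inlines the check at each call site:

```python
# Illustrative helper only; the commit inlines this check at three call sites.
# Assumes an Ascend environment where torch_npu is installed, so that
# accelerate's is_npu_available() returns True.
from accelerate.utils import is_npu_available


def _int_to_npu_device(device):
    # torch_npu rejects bare integer device indices in Tensor.to()
    # (Ascend/pytorch issue #16), so turn 0 into "npu:0" before calling .to().
    if is_npu_available() and isinstance(device, int):
        return f"npu:{device}"
    return device


print(_int_to_npu_device(0))  # "npu:0" on an Ascend machine, 0 elsewhere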
src/accelerate/big_modeling.py (5 changes: 4 additions & 1 deletion)

@@ -36,8 +36,8 @@
     find_tied_parameters,
     get_balanced_memory,
     infer_auto_device_map,
-    is_torch_version,
     is_npu_available,
+    is_torch_version,
     load_checkpoint_in_model,
     offload_state_dict,
     parse_flag_from_env,
@@ -435,6 +435,9 @@ def wrapper(*args, **kwargs):

     else:
         device = list(device_map.values())[0]
+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
         if device != "disk":
             model.to(device)
         else:
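In `dispatch_model`, the `device` being normalized comes from the values of a user-supplied `device_map`, which are often plain integers. A hedged usage sketch of the path this hunk touches (assuming an Ascend machine with `torch_npu` installed):

```python
import torch.nn as nn

from accelerate import dispatch_model

model = nn.Linear(8, 8)
# With the whole model mapped to device index 0, dispatch_model falls through
# to a plain model.to(device). Under torch_npu that call used to be model.to(0)
# and fail; after this commit it becomes model.to("npu:0").
model = dispatch_model(model, device_map={"": 0})
```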
src/accelerate/utils/modeling.py (3 changes: 3 additions & 0 deletions)

@@ -305,6 +305,9 @@ def set_module_tensor_to_device(
     ):
         device_quantization = device
         device = "cpu"
+    # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+    if isinstance(device, int) and is_npu_available():
+        device = f"npu:{device}"
     if value is None:
         new_value = old_value.to(device)
         if dtype is not None and device in ["meta", torch.device("meta")]:
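`set_module_tensor_to_device` takes the target device as an argument, so it can receive the same bare integer. A minimal sketch (again assuming a `torch_npu` setup):

```python
import torch

from accelerate.utils import set_module_tensor_to_device

layer = torch.nn.Linear(2, 2)
# With value=None the function moves the existing parameter, ending in
# old_value.to(device); an integer index is now rewritten to "npu:0" on Ascend
# instead of being passed to .to() directly.
set_module_tensor_to_device(layer, "weight", 0)
```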
src/accelerate/utils/operations.py (5 changes: 4 additions & 1 deletion)

@@ -26,7 +26,7 @@
 from ..state import PartialState
 from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES
 from .dataclasses import DistributedType, TensorInformation
-from .imports import is_torch_distributed_available, is_torch_version, is_tpu_available
+from .imports import is_npu_available, is_torch_distributed_available, is_torch_version, is_tpu_available


 if is_tpu_available(check_device=False):
@@ -164,6 +164,9 @@ def send_to_device(tensor, device, non_blocking=False, skip_keys=None):
             }
         )
     elif hasattr(tensor, "to"):
+        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+        if is_npu_available() and isinstance(device, int):
+            device = f"npu:{device}"
         try:
             return tensor.to(device, non_blocking=non_blocking)
         except TypeError:  # .to() doesn't accept non_blocking as kwarg
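`send_to_device` is the helper that moves whole batches (tensors, dicts, nested structures) in the training loop, so this is the guard most users hit. A minimal usage sketch (assuming an Ascend/`torch_npu` environment):

```python
import torch

from accelerate.utils import send_to_device

batch = {"input_ids": torch.ones(2, 8, dtype=torch.long)}
# Passing the bare index 0 used to end in tensor.to(0, non_blocking=False),
# which torch_npu rejects; with this commit it becomes tensor.to("npu:0", ...).
batch = send_to_device(batch, 0)
```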
