Add Ascend NPU support
  1. add Ascend NPU backend support
  2. refactor func load_model in src/axolotl/utils/models.py
  3. refactor load_in_8bit as a kwarg
MengqingCao committed Jul 19, 2024
1 parent 7830fe0 commit 0bfa6f7
Showing 3 changed files with 577 additions and 446 deletions.
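The changes below assume PyTorch can reach an Ascend device through Huawei's torch_npu plugin: importing that package registers the "npu" device type, and transformers exposes is_torch_npu_available() to probe for it. A minimal smoke test along those lines (an illustration, not part of this commit):

import torch
import torch_npu  # noqa: F401  # importing the Ascend plugin registers the "npu" device type
from transformers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    # torch.npu mirrors the torch.cuda API once torch_npu is loaded
    device = torch.device(f"npu:{torch.npu.current_device()}")
    print(torch.ones(2, 2, device=device).sum())  # executes on the Ascend NPU
else:
    print("No Ascend NPU detected; falling back to CPU")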
25 changes: 24 additions & 1 deletion src/axolotl/utils/config/__init__.py
@@ -7,6 +7,7 @@

 import torch
 from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.import_utils import is_torch_npu_available

 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config.models.input.v0_4_1 import (
@@ -28,8 +29,10 @@ def get_device():

             if torch.backends.mps.is_available():
                 return "mps"
+            if is_torch_npu_available():
+                return f"npu:{cfg.local_rank}"

-            raise SystemError("No CUDA/mps device found")
+            raise SystemError("No CUDA/mps/npu device found")
         except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"

@@ -39,6 +42,8 @@ def get_device():
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
+        elif cfg.device.startswith("npu"):
+            cfg.device_map = {"": torch.npu.current_device()}
         else:
             cfg.device_map = {"": cfg.device}

@@ -91,6 +96,24 @@ def normalize_config(cfg):
         if cfg.bf16:
             cfg.fp16 = True
         cfg.bf16 = False
+    elif cfg.device.startswith("npu"):
+        if cfg.load_in_8bit or cfg.load_in_4bit:
+            LOG.warn("Quantization is currently not supported on npu, disabling for this configuration.")
+            cfg.load_in_8bit = False
+            cfg.load_in_4bit = False
+
+        if cfg.tf32:
+            LOG.warn("tf32 dtype is currently not supported on npu, disabling for this configuration.")
+            cfg.tf32 = False
+
+        if cfg.bf16:
+            LOG.warn("bf16 is currently not supported on npu, casting to fp16.")
+            cfg.fp16 = True
+            cfg.bf16 = False
+
+        if "bit" in cfg.optimizer:
+            LOG.error("{} is currently not supported on npu, choose another one.".format(cfg.optimizer))
+
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
         if cfg.bf16:
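To make the effect of the new branch concrete, here is a simplified, self-contained restatement (illustrative only, not the axolotl API; normalize_for_npu is a hypothetical helper): on an NPU device the config loses quantization, tf32, and bf16, and bitsandbytes-style optimizers are flagged.

def normalize_for_npu(cfg: dict) -> dict:
    """Hypothetical restatement of the NPU branch above, for illustration only."""
    if cfg.get("load_in_8bit") or cfg.get("load_in_4bit"):
        cfg["load_in_8bit"] = cfg["load_in_4bit"] = False  # quantization unsupported on NPU
    if cfg.get("tf32"):
        cfg["tf32"] = False  # tf32 is an NVIDIA-specific matmul mode
    if cfg.get("bf16"):
        cfg["bf16"], cfg["fp16"] = False, True  # fall back to fp16
    if "bit" in cfg.get("optimizer", ""):
        print(f"{cfg['optimizer']} is not supported on npu, choose another optimizer")
    return cfg

print(normalize_for_npu({"bf16": True, "tf32": True, "optimizer": "adamw_torch"}))
# -> {'bf16': False, 'tf32': False, 'optimizer': 'adamw_torch', 'fp16': True}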
43 changes: 27 additions & 16 deletions src/axolotl/utils/distributed.py
@@ -9,9 +9,24 @@
 import torch
 import torch.distributed as dist
 from accelerate import PartialState
+from transformers.utils.import_utils import (
+    is_torch_npu_available,
+    is_torch_cuda_available,
+    is_torch_mps_available
+)

 distributed_state = None  # pylint: disable=invalid-name

+def get_device():
+    device = torch.device("cpu")
+    if is_torch_cuda_available():
+        device = torch.device("cuda")
+    elif is_torch_mps_available():
+        device = torch.device("mps")
+    elif is_torch_npu_available():
+        device = torch.device("npu")
+    return device
+

 def is_distributed():
     """
@@ -83,12 +98,11 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     Returns:
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
+    device = get_device()
     value_scalar = fn()
     if not is_distributed():
         return [value_scalar]
-    value_tensor = torch.tensor(
-        value_scalar, device=torch.cuda.current_device()
-    ).float()
+    value_tensor = torch.tensor(value_scalar, device=device).float()

     if not is_main_process():
         dist.gather(value_tensor, dst=0)
@@ -111,13 +125,14 @@ def broadcast_dict(vals: dict):
     if not is_distributed():
         return vals

+    device = get_device()
     if is_main_process():
         data_byte = pickle.dumps(vals)
-        data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
-        data_size = torch.IntTensor([len(data_byte)]).to("cuda")
+        data_tensor = torch.ByteTensor(list(data_byte)).to(device)
+        data_size = torch.IntTensor([len(data_byte)]).to(device)
     else:
-        data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
-        data_size = torch.IntTensor([0]).to("cuda")
+        data_tensor = torch.empty([1024], dtype=torch.uint8, device=device)
+        data_size = torch.IntTensor([0]).to(device)

     dist.broadcast(data_size, 0)
     if not is_main_process():
@@ -146,15 +161,12 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     Returns:
     - The computed value (int or float).
     """
+    device = get_device()
     if is_main_process():
         value_scalar = fn()
-        value_tensor = torch.tensor(
-            value_scalar, device=torch.cuda.current_device()
-        ).float()
+        value_tensor = torch.tensor(value_scalar, device=device).float()
     else:
-        value_tensor = torch.tensor(
-            0.0, device=torch.cuda.current_device()
-        )  # Placeholder tensor
+        value_tensor = torch.tensor(0.0, device=device)  # Placeholder tensor

     # Broadcast the tensor to all processes.
     barrier()
@@ -178,10 +190,9 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     Returns:
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
+    device = get_device()
     value_scalar = fn()
-    value_tensor = torch.tensor(
-        value_scalar, device=torch.cuda.current_device()
-    ).float()
+    value_tensor = torch.tensor(value_scalar, device=device).float()

     # Placeholder tensor for gathering results
     if is_main_process():
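With the new get_device() helper, the collective utilities in this file stop hard-coding torch.cuda.current_device(). A usage sketch follows (it assumes a process group already launched via accelerate or torchrun; only get_device is new in this commit, the other helpers already live in axolotl.utils.distributed):

import torch

from axolotl.utils.distributed import (
    broadcast_dict,
    compute_and_broadcast,
    get_device,
    is_main_process,
)

device = get_device()  # cuda, mps, npu, or cpu, in that order of preference
print(f"collectives stage tensors on: {device}")

# rank 0 computes the value once; every rank receives it
seed = compute_and_broadcast(lambda: torch.randint(0, 2**31, (1,)).item())

# rank 0's dict is pickled and broadcast (the 1024-byte buffer above caps its size)
shared = broadcast_dict({"seed": seed} if is_main_process() else {})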
src/axolotl/utils/models.py (diff not rendered)
