diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 635789de6..2d783966c 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -62,7 +62,12 @@ echo "Using pip executable: $PIP_EXECUTABLE" # NOTE: If a newly-fetched version of the executorch repo changes the value of # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary # package versions. -PYTORCH_NIGHTLY_VERSION=dev20241002 +if [[ -x "$(command -v xpu-smi)" ]]; +then + PYTORCH_NIGHTLY_VERSION=dev20241001 +else + PYTORCH_NIGHTLY_VERSION=dev20241002 +fi # Nightly version for torchvision VISION_NIGHTLY_VERSION=dev20241002 @@ -85,16 +90,28 @@ then elif [[ -x "$(command -v rocminfo)" ]]; then TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2" +elif [[ -x "$(command -v xpu-smi)" ]]; +then + TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu" else TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu" fi # pip packages needed by exir. -REQUIREMENTS_TO_INSTALL=( - torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}" - torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" - torchtune=="0.4.0.${TUNE_NIGHTLY_VERSION}" -) +if [[ -x "$(command -v xpu-smi)" ]]; +then + REQUIREMENTS_TO_INSTALL=( + torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}" + torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" + torchtune=="0.3.1" + ) +else + REQUIREMENTS_TO_INSTALL=( + torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}" + torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" + torchtune=="0.4.0.${TUNE_NIGHTLY_VERSION}" + ) +fi # Install the requirements. --extra-index-url tells pip to look for package # versions on the provided URL if they aren't available on the default URL. 
diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index f67cb9d0a..02fe9b47a 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -63,7 +63,12 @@ class BuilderArgs: def __post_init__(self): if self.device is None: - self.device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + self.device = "cuda" + elif torch.xpu.is_available(): + self.device = "xpu" + else: + self.device = "cpu" if not ( (self.checkpoint_path and self.checkpoint_path.is_file()) diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index f7d00181b..78be01a92 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -168,8 +168,8 @@ def _add_model_config_args(parser, verb: str) -> None: "--device", type=str, default=None, - choices=["fast", "cpu", "cuda", "mps"], - help="Hardware device to use. Options: fast, cpu, cuda, mps", + choices=["fast", "cpu", "cuda", "mps", "xpu"], + help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu", ) diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py index 2685ec2f3..a0862ff94 100644 --- a/torchchat/utils/build_utils.py +++ b/torchchat/utils/build_utils.py @@ -231,6 +231,8 @@ def find_multiple(n: int, k: int) -> int: def device_sync(device="cpu"): if "cuda" in device: torch.cuda.synchronize(device) + elif "xpu" in device: + torch.xpu.synchronize(device) elif ("cpu" in device) or ("mps" in device): pass else: @@ -279,7 +281,8 @@ def get_device_str(device) -> str: device = ( "cuda" if torch.cuda.is_available() - else "mps" if is_mps_available() else "cpu" + else "mps" if is_mps_available() + else "xpu" if torch.xpu.is_available() else "cpu" ) return device else: @@ -291,7 +294,8 @@ def get_device(device) -> str: device = ( "cuda" if torch.cuda.is_available() - else "mps" if is_mps_available() else "cpu" + else "mps" if is_mps_available() + else "xpu" if torch.xpu.is_available() else "cpu" ) return torch.device(device) diff --git 
a/torchchat/utils/device_info.py b/torchchat/utils/device_info.py index 9c5953944..950c03002 100644 --- a/torchchat/utils/device_info.py +++ b/torchchat/utils/device_info.py @@ -14,7 +14,7 @@ def get_device_info(device: str) -> str: """Returns a human-readable description of the hardware based on a torch.device.type Args: - device: A torch.device.type string: one of {"cpu", "cuda"}. + device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}. Returns: str: A human-readable description of the hardware or an empty string if the device type is unhandled. @@ -37,4 +37,13 @@ def get_device_info(device: str) -> str: if device == "cuda": return torch.cuda.get_device_name(0) + if device == "xpu": + return ( + check_output( + "xpu-smi discovery | grep 'Device Name:'", shell=True + ) + .decode("utf-8") + .split("\n")[0] + .split("Device Name:")[1] + ) return ""