diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py
index e87265a68..12c272a2b 100644
--- a/wenet/bin/alignment.py
+++ b/wenet/bin/alignment.py
@@ -139,6 +139,7 @@ def get_labformat(timestamp, subsample):
     parser.add_argument('--device',
                         type=str,
                         default="cpu",
+                        choices=["cpu", "npu", "cuda"],
                         help='accelerator to use')
     parser.add_argument('--blank_thres',
                         default=0.999999,
diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py
index 1ba1eff3f..3101d6eb3 100644
--- a/wenet/bin/recognize.py
+++ b/wenet/bin/recognize.py
@@ -47,6 +47,7 @@ def get_args():
     parser.add_argument('--device',
                         type=str,
                         default="cpu",
+                        choices=["cpu", "npu", "cuda"],
                         help='accelerator to use')
     parser.add_argument('--dtype',
                         type=str,
diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index 94f4c93d5..a77021c67 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -47,7 +47,9 @@ def get_args():
                         help='Engine for paralleled training')
     # set default value of device to "cuda", avoiding the modify of original scripts
     parser.add_argument('--device',
+                        type=str,
                         default='cuda',
+                        choices=["cpu", "npu", "cuda"],
                         help='accelerator for training')
     parser = add_model_args(parser)
     parser = add_dataset_args(parser)
diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py
index 5dde74546..28bf27919 100644
--- a/wenet/cli/transcribe.py
+++ b/wenet/cli/transcribe.py
@@ -41,6 +41,7 @@ def get_args():
     parser.add_argument('--device',
                         type=str,
                         default='cpu',
+                        choices=["cpu", "npu", "cuda"],
                         help='accelerator to use')
     parser.add_argument('-t',
                         '--show_tokens_info',
diff --git a/wenet/utils/common.py b/wenet/utils/common.py
index 716bfcba7..41488d5c7 100644
--- a/wenet/utils/common.py
+++ b/wenet/utils/common.py
@@ -368,8 +368,9 @@ def is_torch_npu_available() -> bool:
         import torch_npu  # noqa
         return True
     except ImportError:
-        print("Module \"torch_npu\" not found. \"pip install torch_npu\" \
-            if you are using Ascend NPU, otherwise, ignore it")
+        if not torch.cuda.is_available():
+            print("Module \"torch_npu\" not found. \"pip install torch_npu\" \
+                if you are using Ascend NPU, otherwise, ignore it")
     return False