Skip to content

Commit

Permalink
add device choice
Browse files Browse the repository at this point in the history
  • Loading branch information
MengqingCao committed May 30, 2024
1 parent 728f8c2 commit 7f3bdd1
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions wenet/bin/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def get_labformat(timestamp, subsample):
parser.add_argument('--device',
type=str,
default="cpu",
choices=["cpu", "npu", "cuda"],
help='accelerator to use')
parser.add_argument('--blank_thres',
default=0.999999,
Expand Down
1 change: 1 addition & 0 deletions wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get_args():
parser.add_argument('--device',
type=str,
default="cpu",
choices=["cpu", "npu", "cuda"],
help='accelerator to use')
parser.add_argument('--dtype',
type=str,
Expand Down
2 changes: 2 additions & 0 deletions wenet/bin/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def get_args():
help='Engine for paralleled training')
# set default value of device to "cuda", avoiding the modify of original scripts
parser.add_argument('--device',
type=str,
default='cuda',
choices=["cpu", "npu", "cuda"],
help='accelerator for training')
parser = add_model_args(parser)
parser = add_dataset_args(parser)
Expand Down
1 change: 1 addition & 0 deletions wenet/cli/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def get_args():
parser.add_argument('--device',
type=str,
default='cpu',
choices=["cpu", "npu", "cuda"],
help='accelerator to use')
parser.add_argument('-t',
'--show_tokens_info',
Expand Down
5 changes: 3 additions & 2 deletions wenet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,9 @@ def is_torch_npu_available() -> bool:
import torch_npu # noqa
return True
except ImportError:
print("Module \"torch_npu\" not found. \"pip install torch_npu\" \
if you are using Ascend NPU, otherwise, ignore it")
if not torch.cuda.is_available():
print("Module \"torch_npu\" not found. \"pip install torch_npu\" \
if you are using Ascend NPU, otherwise, ignore it")
return False


Expand Down

0 comments on commit 7f3bdd1

Please sign in to comment.