Skip to content

Commit 94b95e2

Browse files
committed
Add NPU Support for Single Agent
1 parent a0d650f commit 94b95e2

File tree

22 files changed

+59
-4
lines changed

22 files changed

+59
-4
lines changed

sota-implementations/a2c/a2c_atari.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ def main(cfg: DictConfig): # noqa: F821
3535

3636
device = cfg.loss.device
3737
if not device:
38-
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
38+
if torch.cuda.is_available():
39+
device = torch.device("cuda:0")
40+
elif torch.npu.is_available():
41+
device = torch.device("npu:0")
42+
else:
43+
device = torch.device("cpu")
3944
else:
4045
device = torch.device(device)
4146

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ def main(cfg: DictConfig): # noqa: F821
3838

3939
device = cfg.loss.device
4040
if not device:
41-
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")
41+
if torch.cuda.is_available():
42+
device = torch.device("cuda:0")
43+
elif torch.npu.is_available():
44+
device = torch.device("npu:0")
45+
else:
46+
device = torch.device("cpu")
4247
else:
4348
device = torch.device(device)
4449

sota-implementations/cql/cql_offline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def main(cfg: DictConfig): # noqa: F821
5959
if device in ("", None):
6060
if torch.cuda.is_available():
6161
device = "cuda:0"
62+
elif torch.npu.is_available():
63+
device = "npu:0"
6264
else:
6365
device = "cpu"
6466
device = torch.device(device)

sota-implementations/cql/cql_online.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def main(cfg: DictConfig): # noqa: F821
6464
if device in ("", None):
6565
if torch.cuda.is_available():
6666
device = "cuda:0"
67+
elif torch.npu.is_available():
68+
device = "npu:0"
6769
else:
6870
device = "cpu"
6971
device = torch.device(device)

sota-implementations/cql/discrete_cql_offline.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@
3838
def main(cfg): # noqa: F821
3939
device = cfg.optim.device
4040
if device in ("", None):
41-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
41+
if torch.cuda.is_available():
42+
device = "cuda:0"
43+
elif torch.npu.is_available():
44+
device = "npu:0"
45+
else:
46+
device = "cpu"
4247
device = torch.device(device)
4348

4449
# Create logger

sota-implementations/cql/discrete_cql_online.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def main(cfg: DictConfig): # noqa: F821
4242
if device in ("", None):
4343
if torch.cuda.is_available():
4444
device = "cuda:0"
45+
elif torch.npu.is_available():
46+
device = "npu:0"
4547
else:
4648
device = "cpu"
4749
device = torch.device(device)

sota-implementations/crossq/crossq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def main(cfg: DictConfig): # noqa: F821
4444
if device in ("", None):
4545
if torch.cuda.is_available():
4646
device = torch.device("cuda:0")
47+
elif torch.npu.is_available():
48+
device = torch.device("npu:0")
4749
else:
4850
device = torch.device("cpu")
4951
device = torch.device(device)

sota-implementations/ddpg/ddpg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def main(cfg: DictConfig): # noqa: F821
4343
if device in ("", None):
4444
if torch.cuda.is_available():
4545
device = "cuda:0"
46+
elif torch.npu.is_available():
47+
device = "npu:0"
4648
else:
4749
device = "cpu"
4850
device = torch.device(device)
@@ -51,6 +53,8 @@ def main(cfg: DictConfig): # noqa: F821
5153
if collector_device in ("", None):
5254
if torch.cuda.is_available():
5355
collector_device = "cuda:0"
56+
elif torch.npu.is_available():
57+
collector_device = "npu:0"
5458
else:
5559
collector_device = "cpu"
5660
collector_device = torch.device(collector_device)

sota-implementations/decision_transformer/dt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def main(cfg: DictConfig): # noqa: F821
4242
if model_device in ("", None):
4343
if torch.cuda.is_available():
4444
model_device = "cuda:0"
45+
elif torch.npu.is_available():
46+
model_device = "npu:0"
4547
else:
4648
model_device = "cpu"
4749
model_device = torch.device(model_device)

sota-implementations/decision_transformer/online_dt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def main(cfg: DictConfig): # noqa: F821
4040
if model_device in ("", None):
4141
if torch.cuda.is_available():
4242
model_device = "cuda:0"
43+
elif torch.npu.is_available():
44+
model_device = "npu:0"
4345
else:
4446
model_device = "cpu"
4547
model_device = torch.device(model_device)

0 commit comments

Comments
 (0)