Skip to content

Commit

Permalink
Model parallel (#538)
Browse files (browse the repository at this point in the history)
* model-parallel for single process

* fix device/device_map

* fix handling for device
  • Loading branch information
winglian authored Sep 13, 2023
1 parent a4e1bb6 commit f6060a6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/axolotl/utils/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):


 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
+    if not torch.cuda.is_available() or device == "auto":
         return (0, 0, 0)

     usage, cache, misc = gpu_memory_usage_all(device)
Expand Down
4 changes: 3 additions & 1 deletion src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def get_device():
return "cpu"

     cfg.device = get_device()
-    if cfg.device_map != "auto":
+    if cfg.world_size == 1:
+        cfg.device_map = "auto"
+    else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}
         else:
Expand Down

0 comments on commit f6060a6

Please sign in to comment.