Support device_map=sequential & max_memory config parameters (axolotl-ai-cloud#903)

* Support device_map sequential (and others). Support max_memory in cfg.

* Update documentation in README accordingly.

* Update README.md

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
brthor and winglian authored Dec 4, 2023
1 parent 9871ecf commit 2189925
Showing 3 changed files with 8 additions and 1 deletion.
README.md — 6 additions & 0 deletions

@@ -612,6 +612,12 @@
 eval_sample_packing:
 sample_packing_eff_est:
 total_num_tokens:
+# Passed through to transformers when loading the model when launched without accelerate
+# Use `sequential` when training w/ model parallelism to limit memory
+device_map:
+# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
+max_memory:
 # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
 # If you already have a lora model trained that you want to load, put that here.
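Both new keys are simple passthroughs to the transformers loader. A minimal sketch of the equivalent call on the transformers side (the model id and memory figures below are hypothetical, not taken from this commit):

    from transformers import AutoModelForCausalLM

    # "sequential" fills GPU 0 up to its cap before spilling onto GPU 1, and
    # so on, which bounds peak usage when training with model parallelism.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # hypothetical model id
        device_map="sequential",
        # Per-device caps: keys are GPU indices plus an optional "cpu" entry.
        max_memory={0: "20GiB", 1: "20GiB", "cpu": "96GiB"},
    )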
src/axolotl/utils/config.py — 1 addition & 1 deletion

@@ -27,7 +27,7 @@ def get_device():
 
     cfg.device = get_device()
     if cfg.world_size == 1:
-        cfg.device_map = "auto"
+        cfg.device_map = cfg.device_map or "auto"
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
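This one-line change is what makes the new config option effective: the old code unconditionally overwrote device_map with "auto" in single-process runs, so a user-supplied value never reached the loader. A standalone sketch of the or fallback (illustration only, not axolotl code):

    device_map = None
    device_map = device_map or "auto"   # unset -> "auto", old default kept

    device_map = "sequential"
    device_map = device_map or "auto"   # user value survives -> "sequential"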
src/axolotl/utils/models.py — 1 addition & 0 deletions

@@ -216,6 +216,7 @@ def load_model(
     model_kwargs = {}
 
     model_kwargs["device_map"] = cfg.device_map
+    model_kwargs["max_memory"] = cfg.max_memory
     model_kwargs["torch_dtype"] = cfg.torch_dtype
 
     if cfg.model_revision:
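Downstream, these kwargs are splatted into the transformers from_pretrained call, so an unset cfg.max_memory simply arrives as max_memory=None, which transformers treats as its default (no explicit per-device cap). A hypothetical continuation (names assumed, not quoted from this file):

    from transformers import AutoModelForCausalLM

    # Sketch: forward the assembled kwargs to the loader.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,        # assumed name, e.g. cfg.base_model
        **model_kwargs,    # includes device_map, max_memory, torch_dtype
    )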
