diff --git a/README.md b/README.md
index 093b8210d..21fb6f4e3 100644
--- a/README.md
+++ b/README.md
@@ -612,6 +612,12 @@ eval_sample_packing:
 sample_packing_eff_est:
 total_num_tokens:
 
+# Passed through to transformers when loading the model (when launched without accelerate)
+# Use `sequential` when training with model parallelism to limit memory
+device_map:
+# Defines the max memory usage per GPU on the system. Passed through to transformers when loading the model.
+max_memory:
+
 # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
 # If you already have a lora model trained that you want to load, put that here.
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index ef8025a3e..120ea7918 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -27,7 +27,7 @@ def get_device():
     cfg.device = get_device()
 
     if cfg.world_size == 1:
-        cfg.device_map = "auto"
+        cfg.device_map = cfg.device_map or "auto"
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 0d8c812f3..9f33523e3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -216,6 +216,7 @@ def load_model(
 
     model_kwargs = {}
     model_kwargs["device_map"] = cfg.device_map
+    model_kwargs["max_memory"] = cfg.max_memory
     model_kwargs["torch_dtype"] = cfg.torch_dtype
 
     if cfg.model_revision:
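
For context, a minimal sketch of how the two new options could be combined in a config YAML when sharding a model across GPUs without accelerate. The values are illustrative, not defaults; max_memory is forwarded verbatim to transformers' from_pretrained, which accepts a mapping of device index (or "cpu") to a memory cap.

# Illustrative only: shard the model sequentially and cap memory per device.
# Keys 0/1 are GPU indices; "cpu" sets an offload budget if layers spill over.
device_map: sequential
max_memory:
  0: 20GiB
  1: 20GiB
  cpu: 64GiB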