Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor Parallel support #2354

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere
float16: true

# Use Tensor parallel
tensor_parallel: true # require multi-gGPU

# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,9 @@ def build(self, total_num_steps):
"accelerator_config"
] = self.cfg.accelerator_config

if self.cfg.tensor_parallel:
training_arguments_kwargs["tp_size"] = torch.cuda.device_count()

if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
Expand Down
9 changes: 9 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,8 @@ class AxolotlInputConfig(
local_rank: Optional[int] = None
ddp: Optional[bool] = None

tensor_parallel: Optional[bool] = None

seed: Optional[int] = None
ddp_timeout: Optional[int] = None
ddp_bucket_cap_mb: Optional[int] = None
Expand Down Expand Up @@ -1371,6 +1373,13 @@ def check_peft_layers_pattern(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_fsdp_tp(cls, data):
if data.get("fsdp") and data.get("tensor_parallel"):
raise ValueError("FSDP with tensor parallelism is not supported yet.")
return data

@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,9 @@ def _configure_zero3_memory_efficient_loading():
return hf_ds_cfg

skip_move_to_device = False
if self.cfg.tensor_parallel:
del self.model_kwargs["device_map"]

if ( # pylint: disable=condition-evals-to-constant)
(self.cfg.fsdp and self.cfg.fsdp_config.fsdp_cpu_ram_efficient_loading)
and not qlora_fsdp
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def prepare_optim_env(cfg):
if not check_cuda_p2p_ib_support():
if os.getenv("NCCL_P2P_DISABLE") is None:
os.environ["NCCL_P2P_DISABLE"] = "1"

if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
Expand Down
Loading