add validation to prevent 8bit lora finetuning on H100s (#1827)
winglian committed Aug 17, 2024
1 parent 803fed3 commit b1d2921
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1267,6 +1267,19 @@ def check_sample_packing_w_sdpa_bf16(cls, data):

        return data

    @model_validator(mode="before")
    @classmethod
    def check_hopper_8bit_lora(cls, data):
        is_sm_90: bool = (
            data["capabilities"]
            and data["capabilities"].get("compute_capability") == "sm_90"
        )
        if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
            # see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
            raise ValueError("8-bit LoRA is not supported on Hopper GPUs")

        return data

    @model_validator(mode="before")
    @classmethod
    def check_fsdp_deepspeed(cls, data):
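For illustration, a minimal standalone sketch of how the new validator behaves (hypothetical driver code, not part of the commit; the dict keys mirror the adapter, load_in_8bit, and capabilities fields used in the diff):

# Sketch of the check_hopper_8bit_lora logic from this commit, lifted out of
# the Pydantic model for demonstration (assumption: the config is a plain
# dict shaped like the fields referenced above).
def check_hopper_8bit_lora(data: dict) -> dict:
    is_sm_90 = bool(
        data.get("capabilities")
        and data["capabilities"].get("compute_capability") == "sm_90"
    )
    if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
        # 8-bit LoRA hits a bitsandbytes bug on Hopper (sm_90); see the
        # issue linked in the diff above.
        raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
    return data

# Rejected: an 8-bit LoRA config on an H100 (compute capability sm_90).
try:
    check_hopper_8bit_lora(
        {
            "adapter": "lora",
            "load_in_8bit": True,
            "capabilities": {"compute_capability": "sm_90"},
        }
    )
except ValueError as err:
    print(err)  # 8-bit LoRA is not supported on Hopper GPUs

# Accepted: a 4-bit QLoRA config on the same hardware passes this check.
check_hopper_8bit_lora(
    {
        "adapter": "qlora",
        "load_in_4bit": True,
        "capabilities": {"compute_capability": "sm_90"},
    }
)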
