
Commit cb0bf07

MNT Remove deprecated use of load_in_8bit (#1811)
Don't pass load_in_8bit to AutoModel.from_pretrained; instead, use BitsAndBytesConfig. There was already a PR to clean this up (#1552), but a slightly later PR (#1518) re-added this usage.

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
1 parent 8cd2cb6 commit cb0bf07
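In short, the supported pattern wraps the quantization option in a BitsAndBytesConfig passed via quantization_config, rather than handing the deprecated load_in_8bit kwarg to from_pretrained directly. A minimal sketch of the migration (model choice and torch_dtype are illustrative, taken from the tests below; assumes transformers with bitsandbytes installed and a GPU available):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Deprecated: quantization flag passed directly to from_pretrained
# model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", load_in_8bit=True)

# Supported: wrap the option in a BitsAndBytesConfig and pass it as quantization_config
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",
    quantization_config=bnb_config,
    torch_dtype=torch.float32,
)

Note that the first hunk also drops bnb_4bit_use_double_quant and bnb_4bit_compute_dtype; these configure 4-bit quantization and so have no effect when loading in 8-bit.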

1 file changed (+4, -8 lines)

tests/test_common_gpu.py (+4, -8)
@@ -826,11 +826,7 @@ def test_8bit_lora_mixed_adapter_batches_lora(self):
         # check that we can pass mixed adapter names to the model
         # note that with 8bit, we have quite a bit of imprecision, therefore we use softmax and higher tolerances
         torch.manual_seed(3000)
-        bnb_config = BitsAndBytesConfig(
-            load_in_8bit=True,
-            bnb_4bit_use_double_quant=False,
-            bnb_4bit_compute_dtype=torch.float32,
-        )
+        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
             quantization_config=bnb_config,
@@ -951,7 +947,7 @@ def test_8bit_dora_inference(self):
         # check for same result with and without DoRA when initializing with init_lora_weights=False
         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
-            load_in_8bit=True,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
             torch_dtype=torch.float32,
         ).eval()

@@ -964,7 +960,7 @@ def test_8bit_dora_inference(self):

         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
-            load_in_8bit=True,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
             torch_dtype=torch.float32,
         )
         torch.manual_seed(0)
@@ -1042,7 +1038,7 @@ def test_8bit_dora_merging(self):
         torch.manual_seed(0)
         model = AutoModelForCausalLM.from_pretrained(
             "facebook/opt-125m",
-            load_in_8bit=True,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
             torch_dtype=torch.float32,
         ).eval()
