
Lazy import of bitsandbytes #1230

Merged 9 commits on Dec 7, 2023
45 changes: 45 additions & 0 deletions scripts/launch_notebook_mp.py
@@ -0,0 +1,45 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This is a minimal example of launching PEFT with Accelerate. This used to cause issues because PEFT would eagerly
# import bitsandbytes, which initializes CUDA, resulting in:
# > RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the
# > 'spawn' start method
# This script exists to ensure that this issue does not reoccur.

import torch
import peft
from accelerate import notebook_launcher


def init():
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 2)

        def forward(self, x):
            return self.linear(x)

    model = MyModule().to("cuda")
    peft.get_peft_model(model, peft.LoraConfig(target_modules=["linear"]))


def main():
    notebook_launcher(init, (), num_processes=2)


if __name__ == "__main__":
    main()
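
A quick sanity check outside of notebook_launcher (a sketch, not part of this PR's diff): after the change, merely importing peft should no longer pull in bitsandbytes as a side effect. Run in a fresh interpreter:

# Sketch only: verify the lazy import this PR introduces. Expected to print
# False after the change, since bitsandbytes is now imported on demand.
import sys

import peft  # noqa: F401

print("bitsandbytes" in sys.modules)
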
15 changes: 9 additions & 6 deletions src/peft/tuners/adalora/__init__.py
@@ -24,12 +24,15 @@
__all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"]


-if is_bnb_available():
-    from .bnb import SVDLinear8bitLt
+def __getattr__(name):
+    if (name == "SVDLinear8bitLt") and is_bnb_available():
+        from .bnb import SVDLinear8bitLt

-    __all__ += ["SVDLinear8bitLt"]
+        return SVDLinear8bitLt

-if is_bnb_4bit_available():
-    from .bnb import SVDLinear4bit
+    if (name == "SVDLinear4bit") and is_bnb_4bit_available():
+        from .bnb import SVDLinear4bit

-    __all__ += ["SVDLinear4bit"]
+        return SVDLinear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
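
The new code relies on a module-level __getattr__ (PEP 562): the bnb-backed classes are imported only when the attribute is actually requested, so importing the subpackage itself no longer touches bitsandbytes. A minimal, self-contained sketch of the same pattern, using an illustrative module rather than PEFT code:

# lazy_stats.py -- illustrative sketch of the PEP 562 lazy-import pattern
__all__ = ["mean"]


def __getattr__(name):
    # Called only when `name` is not found by normal module attribute lookup,
    # so the import below runs on first access instead of at import time.
    if name == "mean":
        from statistics import mean

        return mean

    raise AttributeError(f"module {__name__} has no attribute {name}")

With this in place, import lazy_stats stays cheap, while lazy_stats.mean([1, 2, 3]) (or from lazy_stats import mean) triggers the deferred import on first use, mirroring how SVDLinear8bitLt and SVDLinear4bit are resolved above.
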
16 changes: 8 additions & 8 deletions src/peft/tuners/adalora/model.py
@@ -33,14 +33,6 @@
from .layer import AdaLoraLayer, RankAllocator, SVDLinear


-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import SVDLinear8bitLt
-if is_bnb_4bit_available():
-    from .bnb import SVDLinear4bit


class AdaLoraModel(LoraModel):
"""
Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
@@ -159,6 +151,14 @@ def _create_and_replace(

    @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import SVDLinear8bitLt
+        if is_bnb_4bit_available():
+            from .bnb import SVDLinear4bit
+
        gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

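
The model.py side is the complementary half of the change: the bitsandbytes imports move from module scope into _create_new_module, so the heavy import (and the CUDA initialization it triggers) only happens when a quantized layer is actually created. A small, self-contained sketch of a function-local import, using numpy as a stand-in for the heavy dependency (not PEFT code):

def moving_average(values, window):
    # Deferred import: numpy is only loaded if this function is actually called,
    # mirroring how bitsandbytes is now imported inside _create_new_module
    # rather than at module import time.
    import numpy as np

    return np.convolve(values, np.ones(window) / window, mode="valid")
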
15 changes: 9 additions & 6 deletions src/peft/tuners/ia3/__init__.py
@@ -23,12 +23,15 @@
__all__ = ["Conv2d", "IA3Config", "IA3Layer", "IA3Model", "Linear"]


-if is_bnb_available():
-    from .bnb import Linear8bitLt
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt

-    __all__ += ["Linear8bitLt"]
+        return Linear8bitLt

-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit

-    __all__ += ["Linear4bit"]
+        return Linear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
18 changes: 9 additions & 9 deletions src/peft/tuners/ia3/model.py
@@ -35,15 +35,6 @@
from .layer import Conv2d, IA3Layer, Linear


-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import Linear8bitLt
-
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit


class IA3Model(BaseTuner):
"""
Creates a Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3) model from a pretrained
@@ -86,6 +77,15 @@ def __init__(self, model, config, adapter_name):

    @staticmethod
    def _create_new_module(ia3_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import Linear8bitLt
+
+        if is_bnb_4bit_available():
+            from .bnb import Linear4bit
+
        loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
        loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
        is_feedforward = kwargs.pop("is_feedforward", False)
15 changes: 9 additions & 6 deletions src/peft/tuners/lora/__init__.py
@@ -24,12 +24,15 @@
__all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"]


-if is_bnb_available():
-    from .bnb import Linear8bitLt
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt

-    __all__ += ["Linear8bitLt"]
+        return Linear8bitLt

-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit

-    __all__ += ["Linear4bit"]
+        return Linear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
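
From the caller's side nothing changes: the class can still be imported by name from the subpackage, but bitsandbytes is only loaded at that point. A hedged usage sketch (assumes bitsandbytes is installed):

# Usage sketch: the import below is resolved through the module-level __getattr__
# added above, so bitsandbytes is loaded here rather than when peft was imported.
from peft.tuners.lora import Linear8bitLt  # noqa: F401
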
18 changes: 9 additions & 9 deletions src/peft/tuners/lora/model.py
@@ -45,15 +45,6 @@
from .layer import Conv2d, Embedding, Linear, LoraLayer


-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import Linear8bitLt
-
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit


class LoraModel(BaseTuner):
"""
Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.
@@ -253,6 +244,15 @@ def _mark_only_adapters_as_trainable(self) -> None:

    @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import Linear8bitLt
+
+        if is_bnb_4bit_available():
+            from .bnb import Linear4bit
+
        gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)

12 changes: 12 additions & 0 deletions tests/test_gpu_examples.py
@@ -21,6 +21,8 @@

import pytest
import torch
+from accelerate.test_utils.testing import run_command
+from accelerate.utils import patch_environment
from datasets import Audio, DatasetDict, load_dataset
from transformers import (
    AutoModelForCausalLM,
@@ -933,3 +935,13 @@ def test_causal_lm_training_mutli_gpu(self):

        # assert loss is not None
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+
+@require_bitsandbytes
+@require_torch_gpu
+class MultiprocessTester(unittest.TestCase):
+    def test_notebook_launcher(self):
+        script_path = os.path.join("scripts", "launch_notebook_mp.py")
+        cmd = ["python", script_path]
+        with patch_environment(omp_num_threads=1):
+            run_command(cmd, env=os.environ.copy())