Lazy import of bitsandbytes (#1230)
Previously, we imported from bitsandbytes eagerly if the package was
installed. This caused two major issues:

- Slow loading time of PEFT (~4 sec)
- Errors with multiprocessing, because bitsandbytes initializes CUDA as soon as it is imported

This commit fixes both issues by importing bitsandbytes lazily. The PEFT
import time is now reduced to ~2 sec.

Notes

Implementation-wise, I use a combination of local imports and
module-level __getattr__ (PEP 562). The latter was introduced in Python 3.7
and should therefore be safe to use.
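
As a minimal sketch of the pattern (not PEFT's actual code: the module name lazy_mod and the use of numpy as the stand-in heavy dependency are illustrative assumptions), a module can defer an expensive import until one of its attributes is first looked up:

# lazy_mod.py -- illustrative sketch of a lazy import via module-level __getattr__
# (PEP 562, Python 3.7+). numpy stands in for a heavy optional dependency such as
# bitsandbytes; the names are placeholders, not PEFT's real API.
import importlib.util


def is_numpy_available():
    # Cheap check: only looks for the package on the path, does not import it.
    return importlib.util.find_spec("numpy") is not None


def __getattr__(name):
    # Called only when `name` is not found among the module's globals, so the
    # heavy import runs on first attribute access instead of at import time.
    if name == "ndarray" and is_numpy_available():
        from numpy import ndarray

        return ndarray

    raise AttributeError(f"module {__name__} has no attribute {name}")

With this in place, `import lazy_mod` stays cheap; both `lazy_mod.ndarray` and `from lazy_mod import ndarray` go through the hook and trigger the numpy import only at that point.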
BenjaminBossan authored Dec 7, 2023
1 parent 2ab005f commit b467e3d
Showing 8 changed files with 110 additions and 44 deletions.
45 changes: 45 additions & 0 deletions scripts/launch_notebook_mp.py
@@ -0,0 +1,45 @@
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This is a minimal example of launching PEFT with Accelerate. This used to cause issues because PEFT would eagerly
# import bitsandbytes, which initializes CUDA, resulting in:
# > RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the
# > 'spawn' start method
# This script exists to ensure that this issue does not reoccur.

import torch
import peft
from accelerate import notebook_launcher


def init():
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(1, 2)

        def forward(self, x):
            return self.linear(x)

    model = MyModule().to("cuda")
    peft.get_peft_model(model, peft.LoraConfig(target_modules=["linear"]))


def main():
    notebook_launcher(init, (), num_processes=2)


if __name__ == "__main__":
    main()
15 changes: 9 additions & 6 deletions src/peft/tuners/adalora/__init__.py
@@ -24,12 +24,15 @@
 __all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"]
 
 
-if is_bnb_available():
-    from .bnb import SVDLinear8bitLt
+def __getattr__(name):
+    if (name == "SVDLinear8bitLt") and is_bnb_available():
+        from .bnb import SVDLinear8bitLt
 
-    __all__ += ["SVDLinear8bitLt"]
+        return SVDLinear8bitLt
 
-if is_bnb_4bit_available():
-    from .bnb import SVDLinear4bit
+    if (name == "SVDLinear4bit") and is_bnb_4bit_available():
+        from .bnb import SVDLinear4bit
 
-    __all__ += ["SVDLinear4bit"]
+        return SVDLinear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
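
For context, a hedged sketch of how this behaves from the caller's side (assuming bitsandbytes is installed; without it, the lookup falls through to the AttributeError above):

# Importing the subpackage no longer touches bitsandbytes at all.
import peft.tuners.adalora as adalora

# The first access goes through the module-level __getattr__ above and only
# then performs `from .bnb import SVDLinear8bitLt`.
SVDLinear8bitLt = adalora.SVDLinear8bitLt  # assumes bitsandbytes is installed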
16 changes: 8 additions & 8 deletions src/peft/tuners/adalora/model.py
@@ -33,14 +33,6 @@
 from .layer import AdaLoraLayer, RankAllocator, SVDLinear
 
 
-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import SVDLinear8bitLt
-if is_bnb_4bit_available():
-    from .bnb import SVDLinear4bit
-
-
 class AdaLoraModel(LoraModel):
     """
     Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
@@ -159,6 +151,14 @@ def _create_and_replace(
 
     @staticmethod
     def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import SVDLinear8bitLt
+        if is_bnb_4bit_available():
+            from .bnb import SVDLinear4bit
+
         gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
 
15 changes: 9 additions & 6 deletions src/peft/tuners/ia3/__init__.py
@@ -23,12 +23,15 @@
 __all__ = ["Conv2d", "IA3Config", "IA3Layer", "IA3Model", "Linear"]
 
 
-if is_bnb_available():
-    from .bnb import Linear8bitLt
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt
 
-    __all__ += ["Linear8bitLt"]
+        return Linear8bitLt
 
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit
 
-    __all__ += ["Linear4bit"]
+        return Linear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
18 changes: 9 additions & 9 deletions src/peft/tuners/ia3/model.py
@@ -35,15 +35,6 @@
 from .layer import Conv2d, IA3Layer, Linear
 
 
-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import Linear8bitLt
-
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
-
-
 class IA3Model(BaseTuner):
     """
     Creates a Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3) model from a pretrained
@@ -86,6 +77,15 @@ def __init__(self, model, config, adapter_name):
 
     @staticmethod
     def _create_new_module(ia3_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import Linear8bitLt
+
+        if is_bnb_4bit_available():
+            from .bnb import Linear4bit
+
         loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
         loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
         is_feedforward = kwargs.pop("is_feedforward", False)
15 changes: 9 additions & 6 deletions src/peft/tuners/lora/__init__.py
@@ -24,12 +24,15 @@
 __all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"]
 
 
-if is_bnb_available():
-    from .bnb import Linear8bitLt
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt
 
-    __all__ += ["Linear8bitLt"]
+        return Linear8bitLt
 
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit
 
-    __all__ += ["Linear4bit"]
+        return Linear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
18 changes: 9 additions & 9 deletions src/peft/tuners/lora/model.py
@@ -45,15 +45,6 @@
 from .layer import Conv2d, Embedding, Linear, LoraLayer
 
 
-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import Linear8bitLt
-
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
-
-
 class LoraModel(BaseTuner):
     """
     Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.
@@ -253,6 +244,15 @@ def _mark_only_adapters_as_trainable(self) -> None:
 
     @staticmethod
    def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import Linear8bitLt
+
+        if is_bnb_4bit_available():
+            from .bnb import Linear4bit
+
         gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
 
12 changes: 12 additions & 0 deletions tests/test_gpu_examples.py
@@ -21,6 +21,8 @@
 
 import pytest
 import torch
+from accelerate.test_utils.testing import run_command
+from accelerate.utils import patch_environment
 from datasets import Audio, DatasetDict, load_dataset
 from transformers import (
     AutoModelForCausalLM,
@@ -933,3 +935,13 @@ def test_causal_lm_training_mutli_gpu(self):
 
         # assert loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+
+@require_bitsandbytes
+@require_torch_gpu
+class MultiprocessTester(unittest.TestCase):
+    def test_notebook_launcher(self):
+        script_path = os.path.join("scripts", "launch_notebook_mp.py")
+        cmd = ["python", script_path]
+        with patch_environment(omp_num_threads=1):
+            run_command(cmd, env=os.environ.copy())
