diff --git a/scripts/launch_notebook_mp.py b/scripts/launch_notebook_mp.py
new file mode 100644
index 0000000000..56f6ddbd13
--- /dev/null
+++ b/scripts/launch_notebook_mp.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This is a minimal example of launching PEFT with Accelerate. This used to cause issues because PEFT would eagerly
+# import bitsandbytes, which initializes CUDA, resulting in:
+# > RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the
+# > 'spawn' start method
+# This script exists to ensure that this issue does not reoccur.
+
+import torch
+import peft
+from accelerate import notebook_launcher
+
+
+def init():
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(1, 2)
+
+        def forward(self, x):
+            return self.linear(x)
+
+    model = MyModule().to("cuda")
+    peft.get_peft_model(model, peft.LoraConfig(target_modules=["linear"]))
+
+
+def main():
+    notebook_launcher(init, (), num_processes=2)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/peft/tuners/adalora/__init__.py b/src/peft/tuners/adalora/__init__.py
index 2012c6ca24..431ac44066
--- a/src/peft/tuners/adalora/__init__.py
+++ b/src/peft/tuners/adalora/__init__.py
@@ -24,12 +24,15 @@
 __all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "SVDLinear", "RankAllocator", "SVDQuantLinear"]
 
 
-if is_bnb_available():
-    from .bnb import SVDLinear8bitLt
+def __getattr__(name):
+    if (name == "SVDLinear8bitLt") and is_bnb_available():
+        from .bnb import SVDLinear8bitLt
 
-    __all__ += ["SVDLinear8bitLt"]
+        return SVDLinear8bitLt
 
-if is_bnb_4bit_available():
-    from .bnb import SVDLinear4bit
+    if (name == "SVDLinear4bit") and is_bnb_4bit_available():
+        from .bnb import SVDLinear4bit
 
-    __all__ += ["SVDLinear4bit"]
+        return SVDLinear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py
index 7ccf13e8c9..9c2e6b7e56
--- a/src/peft/tuners/adalora/model.py
+++ b/src/peft/tuners/adalora/model.py
@@ -33,14 +33,6 @@
 from .layer import AdaLoraLayer, RankAllocator, SVDLinear
 
 
-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import SVDLinear8bitLt
-if is_bnb_4bit_available():
-    from .bnb import SVDLinear4bit
-
-
 class AdaLoraModel(LoraModel):
     """
     Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
@@ -159,6 +151,14 @@ def _create_and_replace(
 
     @staticmethod
     def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import SVDLinear8bitLt
+        if is_bnb_4bit_available():
+            from .bnb import SVDLinear4bit
+
         gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
 
diff --git a/src/peft/tuners/ia3/__init__.py b/src/peft/tuners/ia3/__init__.py
index 905e83758f..44100190b5
--- a/src/peft/tuners/ia3/__init__.py
+++ b/src/peft/tuners/ia3/__init__.py
@@ -23,12 +23,15 @@
 __all__ = ["Conv2d", "IA3Config", "IA3Layer", "IA3Model", "Linear"]
 
 
-if is_bnb_available():
-    from .bnb import Linear8bitLt
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt
 
-    __all__ += ["Linear8bitLt"]
+        return Linear8bitLt
 
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit
 
-    __all__ += ["Linear4bit"]
+        return Linear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py
index 409d1bdcd5..a6617e6ef8
--- a/src/peft/tuners/ia3/model.py
+++ b/src/peft/tuners/ia3/model.py
@@ -35,15 +35,6 @@
 from .layer import Conv2d, IA3Layer, Linear
 
 
-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import Linear8bitLt
-
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
-
-
 class IA3Model(BaseTuner):
     """
     Creates a Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3) model from a pretrained
@@ -86,6 +77,15 @@ def __init__(self, model, config, adapter_name):
 
     @staticmethod
     def _create_new_module(ia3_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import Linear8bitLt
+
+        if is_bnb_4bit_available():
+            from .bnb import Linear4bit
+
         loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
         loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
         is_feedforward = kwargs.pop("is_feedforward", False)
diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py
index ddc81d53cd..1eb90b3ccf
--- a/src/peft/tuners/lora/__init__.py
+++ b/src/peft/tuners/lora/__init__.py
@@ -24,12 +24,15 @@
 __all__ = ["LoraConfig", "LoftQConfig", "Conv2d", "Embedding", "LoraLayer", "Linear", "LoraModel", "QuantLinear"]
 
 
-if is_bnb_available():
-    from .bnb import Linear8bitLt
+def __getattr__(name):
+    if (name == "Linear8bitLt") and is_bnb_available():
+        from .bnb import Linear8bitLt
 
-    __all__ += ["Linear8bitLt"]
+        return Linear8bitLt
 
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
+    if (name == "Linear4bit") and is_bnb_4bit_available():
+        from .bnb import Linear4bit
 
-    __all__ += ["Linear4bit"]
+        return Linear4bit
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index 778d5a9ff7..f6b6453b62
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -45,15 +45,6 @@
 from .layer import Conv2d, Embedding, Linear, LoraLayer
 
 
-if is_bnb_available():
-    import bitsandbytes as bnb
-
-    from .bnb import Linear8bitLt
-
-if is_bnb_4bit_available():
-    from .bnb import Linear4bit
-
-
 class LoraModel(BaseTuner):
     """
     Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.
@@ -253,6 +244,15 @@ def _mark_only_adapters_as_trainable(self) -> None:
 
     @staticmethod
     def _create_new_module(lora_config, adapter_name, target, **kwargs):
+        # avoid eager bnb import
+        if is_bnb_available():
+            import bitsandbytes as bnb
+
+            from .bnb import Linear8bitLt
+
+        if is_bnb_4bit_available():
+            from .bnb import Linear4bit
+
         gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
         AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
 
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index 1af1919ad3..e99d6c414d
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -21,6 +21,8 @@
 
 import pytest
 import torch
+from accelerate.test_utils.testing import run_command
+from accelerate.utils import patch_environment
 from datasets import Audio, DatasetDict, load_dataset
 from transformers import (
     AutoModelForCausalLM,
@@ -933,3 +935,13 @@ def test_causal_lm_training_mutli_gpu(self):
 
         # assert loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+
+@require_bitsandbytes
+@require_torch_gpu
+class MultiprocessTester(unittest.TestCase):
+    def test_notebook_launcher(self):
+        script_path = os.path.join("scripts", "launch_notebook_mp.py")
+        cmd = ["python", script_path]
+        with patch_environment(omp_num_threads=1):
+            run_command(cmd, env=os.environ.copy())
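
Note on the pattern used above: the module-level __getattr__ added to the tuner __init__.py files relies on PEP 562, where Python calls the function only when an attribute is not found by normal lookup. As a result, importing peft no longer imports bitsandbytes (and therefore no longer initializes CUDA); the import runs on first access to a bnb-backed class. Below is a minimal, standalone sketch of the same idea, with hypothetical names (heavy_backend, HeavyLinear, lazy_attrs) that are not part of PEFT or bitsandbytes:

# lazy_attrs.py -- sketch of a PEP 562 lazy import; "heavy_backend" and "HeavyLinear"
# are hypothetical stand-ins for bitsandbytes and the bnb-backed layer classes.
import importlib
import importlib.util


def _heavy_backend_available() -> bool:
    # Availability check that does not import (and thus does not initialize) the dependency.
    return importlib.util.find_spec("heavy_backend") is not None


def __getattr__(name):
    # Python calls this only when `name` is not found through normal module attribute
    # lookup, so `import lazy_attrs` alone never touches heavy_backend.
    if name == "HeavyLinear" and _heavy_backend_available():
        heavy = importlib.import_module("heavy_backend")  # deferred import happens here
        return heavy.HeavyLinear

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

With this shape, importing the package stays side-effect free (and therefore safe for fork-based launchers such as notebook_launcher), while accessing lazy_attrs.HeavyLinear triggers the expensive import only when it is actually needed; this mirrors the behavior that scripts/launch_notebook_mp.py and the MultiprocessTester test above verify for PEFT.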