
just import of peft leads to use of cuda and introduces cuda context that makes forks not possible #559

Closed
pseudotensor opened this issue Jun 9, 2023 · 24 comments


@pseudotensor

pseudotensor commented Jun 9, 2023

Older versions of peft did not do this, but newer versions do, at least since 3714aa2 and likely earlier.

This means that if peft is imported at global scope, as is normal, no forks are possible anymore in Python, which breaks multiprocessing etc.

The use of CUDA should be lazy and on-demand, not forced by importing peft. That is, a CUDA context should only be created when the model itself is placed on CUDA, not merely because peft was imported.

jon@pseudotensor:~/h2ogpt$ python
Python 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from peft import PeftModel

===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
================================================================================
bin /home/jon/miniconda3/envs/h2ollm/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121.so
/home/jon/miniconda3/envs/h2ollm/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/lib/jvm/default-java/jre/lib/amd64/server'), PosixPath('/opt/clang+llvm-4.0.0-x86_64-linux-gnu-ubuntu-16.04/lib'), PosixPath('/home/jon/lib'), PosixPath('/opt/rstudio-1.0.136/bin'), PosixPath('/usr/local/cuda/extras/CUPTI/lib64')}
  warn(msg)
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 121
CUDA SETUP: Loading binary /home/jon/miniconda3/envs/h2ollm/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121.so...
>>> 

Because of this CUDA setup at import time, the following, which does not import peft, works:

from concurrent.futures import ProcessPoolExecutor


def go():
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("h2oai/h2ogpt-oig-oasst1-512-6_9b", load_in_8bit=True, device_map={"": 'cuda'})
    assert model is not None


with ProcessPoolExecutor(max_workers=1) as executor:
    ret = executor.submit(go).result()

But this fails:

from peft import PeftModel
from concurrent.futures import ProcessPoolExecutor


def go():
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("h2oai/h2ogpt-oig-oasst1-512-6_9b", load_in_8bit=True, device_map={"": 'cuda'})
    assert model is not None


with ProcessPoolExecutor(max_workers=1) as executor:
    ret = executor.submit(go).result()

with:

E               RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
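
For completeness, the 'spawn' start method named in the error does avoid the crash, though it gives up the cheap fork semantics this issue is about. A minimal sketch of the same repro with spawn:

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

from peft import PeftModel  # still creates a CUDA context in the parent


def go():
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained("h2oai/h2ogpt-oig-oasst1-512-6_9b", load_in_8bit=True, device_map={"": 'cuda'})
    assert model is not None


if __name__ == "__main__":
    # Spawned workers start a fresh interpreter, so the parent's CUDA
    # context is never inherited.
    with ProcessPoolExecutor(max_workers=1, mp_context=mp.get_context("spawn")) as executor:
        ret = executor.submit(go).result()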
@pseudotensor
Author

pseudotensor commented Jun 9, 2023

A local-scope import isn't a good work-around:
h2oai/h2ogpt@762cdcd

because if peft is used at all, the same problems appear. CUDA needs to stay isolated until the model is actually put onto CUDA, and only then.

@younesbelkada
Contributor

Hi @pseudotensor,
Thanks for raising the issue. Could protecting the import of the bnb-related modules in PEFT with if torch.cuda.is_available() solve your issue?
Per my understanding, the command

from peft import PeftModel

has always led to importing the bnb-related modules if bnb is installed.
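
A minimal sketch of that kind of guard; the helper and the import placement are illustrative assumptions, not PEFT's actual code:

import importlib.util

import torch


def is_bnb_available() -> bool:
    return importlib.util.find_spec("bitsandbytes") is not None


if torch.cuda.is_available() and is_bnb_available():
    import bitsandbytes as bnb  # only imported when CUDA is actually usable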

@pseudotensor
Author

It is definitely new behavior. I can’t just wrap the import.

@younesbelkada
Contributor

Thanks, I see. Would you be able to share exactly which commit this starts from? I will also investigate on my side and let you know.

@pseudotensor
Author

I don't know the exact one, but the above code can be used to bisect.

@pseudotensor
Author

@younesbelkada Any update? It seems like this should be easy to fix.

@younesbelkada
Contributor

Hi @pseudotensor,
I still haven't found time to properly address this.
I would also love to hear @BenjaminBossan's thoughts here, in case I missed a few things.

@BenjaminBossan
Member

I ran a git bisect and the offending commit is this one: d75746b

So it is indeed the top-level import of bnb that causes the issue. Commenting out the import fixes it.

With the current state of the code base, it might be possible to prevent any top-level imports of bnb, but it wouldn't be trivial. I do wonder, however, whether it would be possible for bnb to make a change to avoid the issue. I don't know that library well, so maybe someone else can comment on that?

@pseudotensor
Author

Hi, but with that change, peft now depends fully on bitsandbytes, even though it is needed for just one component. bitsandbytes is not trivial to install on every system, e.g. Windows or Mac. So this limits peft quite a bit unless it is fixed inside peft itself.

@younesbelkada
Contributor

younesbelkada commented Jul 6, 2023

I agree with:

bitsandbytes is not trivial to install on every system, e.g. windows or mac.

We should probably think about making bitsandbytes an optional dependency, and that should fix this issue. wdyt @pacman100?
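
For illustration, optional dependencies are usually declared as a setuptools extra. A minimal sketch, where the extra name and the dependency lists are assumptions rather than PEFT's actual setup.py:

from setuptools import setup

setup(
    name="peft",
    install_requires=["torch", "transformers"],  # core deps only, no bnb
    # hypothetical extra: `pip install peft[bnb]` pulls in bitsandbytes
    extras_require={"bnb": ["bitsandbytes"]},
)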

@pseudotensor
Author

pseudotensor commented Jul 6, 2023

But that won't solve the issue here, because bitsandbytes messes up multiprocessing by initializing CUDA at import time. I should be able to use CUDA and have bitsandbytes installed, without an import of peft necessarily loading bitsandbytes globally.

bitsandbytes should only be loaded when the model object requires it, not at global import time.

@pseudotensor
Author

pseudotensor commented Jul 6, 2023

A simple fix for peft is to put the Linear8bitLt class in a separate file and only import it locally. Then it will be imported inside LoraModel and never cause any problems, because it is not imported until a LoraModel is created (i.e. a model-time import).
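
A minimal sketch of such a model-time import; LoraModel here and the _create_new_module method are illustrative stand-ins, not PEFT's real internals:

import torch.nn as nn


class LoraModel(nn.Module):
    def _create_new_module(self, target: nn.Linear, loaded_in_8bit: bool) -> nn.Module:
        if loaded_in_8bit:
            # bitsandbytes (and its CUDA setup) is only imported once an
            # 8-bit layer is actually constructed, not at `import peft` time.
            import bitsandbytes as bnb
            return bnb.nn.Linear8bitLt(target.in_features, target.out_features)
        return nn.Linear(target.in_features, target.out_features)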

@BenjaminBossan
Member

A simple fix for peft is to put the class in a separate file and only import it locally

As I mentioned, I think it should be possible to load bnb lazily; what you suggest is one possibility. But it isn't a trivial change: for instance, we would have to ensure in our test suite that nothing breaks whether bnb is installed or not.

I should be able to use CUDA and have bitsandbytes

Assuming you do have bnb, then this is irrelevant, right?

bitsandbytes is not trivial to install on every system, e.g. windows or mac

That last part should be fixed by making bnb an optional dependency, right?

IMO if bnb could do something about this issue, it would still be a win (but I don't know how trivial it is for them to fix it).

@pseudotensor
Author

pseudotensor commented Jul 6, 2023

I presume you have tests of LoRA that use bitsandbytes, so such local imports would be tested.

To test that bitsandbytes is no longer breaking anything, just add the trivial repro I provided in this issue to your tests.

I don't see these as non-trivial.

@BenjaminBossan
Member

Maybe I'm missing something, but what I mean is that to test that the dependency on bnb is indeed optional, we have to create an env in our test setup (or mock the existing one) to pretend that bnb is not installed and run all the tests (except the bnb-specific ones) to ensure that they still pass without bnb installed. Otherwise, we might have code that does depend on bnb when it shouldn't, but if bnb is installed in the test env, we wouldn't notice it. Does that make sense?
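
One way to do that, as a sketch assuming pytest rather than PEFT's actual test harness, is to poison sys.modules so the package looks uninstalled:

import sys

import pytest


@pytest.fixture
def no_bnb(monkeypatch):
    # A None entry in sys.modules makes `import bitsandbytes` raise
    # ImportError, as if the package were absent; this only works if
    # bnb has not already been imported earlier in the process.
    monkeypatch.setitem(sys.modules, "bitsandbytes", None)
    monkeypatch.delitem(sys.modules, "peft", raising=False)


def test_import_peft_without_bnb(no_bnb):
    import peft  # must succeed even when bnb "is not installed"
    assert peft is not None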

@BenjaminBossan
Member

Btw, as a workaround until we have a fix, you should be able to patch builtins.__import__, intercept it when bnb is being imported, and return a mock object.
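
A rough sketch of that workaround, assuming it runs before peft is first imported:

import builtins
from unittest.mock import MagicMock

_real_import = builtins.__import__


def _guarded_import(name, *args, **kwargs):
    # Hand back a mock instead of the real bitsandbytes so its
    # import-time CUDA setup never runs.
    if name == "bitsandbytes" or name.startswith("bitsandbytes."):
        return MagicMock()
    return _real_import(name, *args, **kwargs)


builtins.__import__ = _guarded_import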

@pseudotensor
Author

Maybe I'm missing something, but what I mean is that to test that the dependency on bnb is indeed optional, we have to create an env in our test setup (or mock the existing one) to pretend that bnb is not installed and run all the tests (except the bnb-specific ones) to ensure that they still pass without bnb installed. Otherwise, we might have code that does depend on bnb when it shouldn't, but if bnb is installed in the test env, we wouldn't notice it. Does that make sense?

This would be an improvement even over the situation before the problem started, and would be nice to have. The fixes I suggested are, I think, critical.

@pseudotensor
Author

Btw, as a workaround until we have a fix, you should be able to patch builtins.__import__, intercept it when bnb is being imported, and return a mock object.

I still want to use bitsandbytes; I just don't want it imported early, contaminating the global scope with a CUDA context and making forking and CUDA tasks impossible (the original issue I reported).

@BenjaminBossan
Member

I still want to use bitsandbytes; I just don't want it imported early, contaminating the global scope with a CUDA context and making forking and CUDA tasks impossible (the original issue I reported).

Ah I see, in that case this suggestion wouldn't actually work.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions bot closed this as completed Aug 8, 2023
@pseudotensor
Author

pseudotensor commented Aug 18, 2023

Any updates? This shouldn't be closed, IMO.

younesbelkada reopened this Aug 18, 2023
@younesbelkada
Contributor

Yes, re-opened. We're still discussing this internally.

huggingface deleted a comment from github-actions bot Sep 13, 2023
huggingface deleted a comment from github-actions bot Oct 18, 2023
@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@BenjaminBossan
Member

Sorry that this took so long. #1230 should have fixed the issue.
