Skip to content

Commit e3eeabf

Browse files
FIX BOFT setting env vars breaks C++ compilation (#1739)
Resolves #1738
1 parent ae1ae20 commit e3eeabf

File tree

1 file changed

+47
-12
lines changed

1 file changed

+47
-12
lines changed

src/peft/tuners/boft/layer.py

+47-12
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import math
2121
import os
2222
import warnings
23+
from contextlib import contextmanager
2324
from typing import Any, Optional, Union
2425

2526
import torch
@@ -31,13 +32,46 @@
3132
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
3233

3334

34-
os.environ["CC"] = "gcc"
35-
os.environ["CXX"] = "gcc"
36-
curr_dir = os.path.dirname(__file__)
37-
3835
_FBD_CUDA = None
3936

4037

38+
# this function is a 1:1 copy from accelerate
39+
@contextmanager
40+
def patch_environment(**kwargs):
41+
"""
42+
A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
43+
44+
Will convert the values in `kwargs` to strings and upper-case all the keys.
45+
46+
Example:
47+
48+
```python
49+
>>> import os
50+
>>> from accelerate.utils import patch_environment
51+
52+
>>> with patch_environment(FOO="bar"):
53+
... print(os.environ["FOO"]) # prints "bar"
54+
>>> print(os.environ["FOO"]) # raises KeyError
55+
```
56+
"""
57+
existing_vars = {}
58+
for key, value in kwargs.items():
59+
key = key.upper()
60+
if key in os.environ:
61+
existing_vars[key] = os.environ[key]
62+
os.environ[key] = str(value)
63+
64+
yield
65+
66+
for key in kwargs:
67+
key = key.upper()
68+
if key in existing_vars:
69+
# restore previous value
70+
os.environ[key] = existing_vars[key]
71+
else:
72+
os.environ.pop(key, None)
73+
74+
4175
def get_fbd_cuda():
4276
global _FBD_CUDA
4377

@@ -47,14 +81,15 @@ def get_fbd_cuda():
4781
curr_dir = os.path.dirname(__file__)
4882
# need ninja to build the extension
4983
try:
50-
fbd_cuda = load(
51-
name="fbd_cuda",
52-
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
53-
verbose=True,
54-
# build_directory='/tmp/' # for debugging
55-
)
56-
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
57-
import fbd_cuda
84+
with patch_environment(CC="gcc", CXX="gcc"):
85+
fbd_cuda = load(
86+
name="fbd_cuda",
87+
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
88+
verbose=True,
89+
# build_directory='/tmp/' # for debugging
90+
)
91+
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
92+
import fbd_cuda
5893
except Exception as e:
5994
warnings.warn(f"Failed to load the CUDA extension: {e}, check if ninja is available.")
6095
warnings.warn("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process.")

0 commit comments

Comments
 (0)