Skip to content

Commit

Permalink
FIX BOFT setting env vars breaks C++ compilation (#1739)
Browse files Browse the repository at this point in the history
Resolves #1738
  • Loading branch information
BenjaminBossan committed May 17, 2024
1 parent 0649947 commit 2276c6e
Showing 1 changed file with 47 additions and 12 deletions.
59 changes: 47 additions & 12 deletions src/peft/tuners/boft/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import math
import os
import warnings
from contextlib import contextmanager
from typing import Any, Optional, Union

import torch
Expand All @@ -31,13 +32,46 @@
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge


os.environ["CC"] = "gcc"
os.environ["CXX"] = "gcc"
curr_dir = os.path.dirname(__file__)

_FBD_CUDA = None


# this function is a 1:1 copy from accelerate
@contextmanager
def patch_environment(**kwargs):
"""
A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
Will convert the values in `kwargs` to strings and upper-case all the keys.
Example:
```python
>>> import os
>>> from accelerate.utils import patch_environment
>>> with patch_environment(FOO="bar"):
... print(os.environ["FOO"]) # prints "bar"
>>> print(os.environ["FOO"]) # raises KeyError
```
"""
existing_vars = {}
for key, value in kwargs.items():
key = key.upper()
if key in os.environ:
existing_vars[key] = os.environ[key]
os.environ[key] = str(value)

yield

for key in kwargs:
key = key.upper()
if key in existing_vars:
# restore previous value
os.environ[key] = existing_vars[key]
else:
os.environ.pop(key, None)


def get_fbd_cuda():
global _FBD_CUDA

Expand All @@ -47,14 +81,15 @@ def get_fbd_cuda():
curr_dir = os.path.dirname(__file__)
# need ninja to build the extension
try:
fbd_cuda = load(
name="fbd_cuda",
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
verbose=True,
# build_directory='/tmp/' # for debugging
)
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
import fbd_cuda
with patch_environment(CC="gcc", CXX="gcc"):
fbd_cuda = load(
name="fbd_cuda",
sources=[f"{curr_dir}/fbd/fbd_cuda.cpp", f"{curr_dir}/fbd/fbd_cuda_kernel.cu"],
verbose=True,
# build_directory='/tmp/' # for debugging
)
# extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
import fbd_cuda
except Exception as e:
warnings.warn(f"Failed to load the CUDA extension: {e}, check if ninja is available.")
warnings.warn("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process.")
Expand Down

0 comments on commit 2276c6e

Please sign in to comment.