20
20
import math
21
21
import os
22
22
import warnings
23
+ from contextlib import contextmanager
23
24
from typing import Any , Optional , Union
24
25
25
26
import torch
31
32
from peft .tuners .tuners_utils import BaseTunerLayer , check_adapters_to_merge
32
33
33
34
34
- os .environ ["CC" ] = "gcc"
35
- os .environ ["CXX" ] = "gcc"
36
- curr_dir = os .path .dirname (__file__ )
37
-
38
35
_FBD_CUDA = None
39
36
40
37
38
+ # this function is a 1:1 copy from accelerate
39
+ @contextmanager
40
+ def patch_environment (** kwargs ):
41
+ """
42
+ A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
43
+
44
+ Will convert the values in `kwargs` to strings and upper-case all the keys.
45
+
46
+ Example:
47
+
48
+ ```python
49
+ >>> import os
50
+ >>> from accelerate.utils import patch_environment
51
+
52
+ >>> with patch_environment(FOO="bar"):
53
+ ... print(os.environ["FOO"]) # prints "bar"
54
+ >>> print(os.environ["FOO"]) # raises KeyError
55
+ ```
56
+ """
57
+ existing_vars = {}
58
+ for key , value in kwargs .items ():
59
+ key = key .upper ()
60
+ if key in os .environ :
61
+ existing_vars [key ] = os .environ [key ]
62
+ os .environ [key ] = str (value )
63
+
64
+ yield
65
+
66
+ for key in kwargs :
67
+ key = key .upper ()
68
+ if key in existing_vars :
69
+ # restore previous value
70
+ os .environ [key ] = existing_vars [key ]
71
+ else :
72
+ os .environ .pop (key , None )
73
+
74
+
41
75
def get_fbd_cuda ():
42
76
global _FBD_CUDA
43
77
@@ -47,14 +81,15 @@ def get_fbd_cuda():
47
81
curr_dir = os .path .dirname (__file__ )
48
82
# need ninja to build the extension
49
83
try :
50
- fbd_cuda = load (
51
- name = "fbd_cuda" ,
52
- sources = [f"{ curr_dir } /fbd/fbd_cuda.cpp" , f"{ curr_dir } /fbd/fbd_cuda_kernel.cu" ],
53
- verbose = True ,
54
- # build_directory='/tmp/' # for debugging
55
- )
56
- # extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
57
- import fbd_cuda
84
+ with patch_environment (CC = "gcc" , CXX = "gcc" ):
85
+ fbd_cuda = load (
86
+ name = "fbd_cuda" ,
87
+ sources = [f"{ curr_dir } /fbd/fbd_cuda.cpp" , f"{ curr_dir } /fbd/fbd_cuda_kernel.cu" ],
88
+ verbose = True ,
89
+ # build_directory='/tmp/' # for debugging
90
+ )
91
+ # extra_cuda_cflags = ['-std=c++14', '-ccbin=$$(which gcc-7)']) # cuda10.2 is not compatible with gcc9. Specify gcc 7
92
+ import fbd_cuda
58
93
except Exception as e :
59
94
warnings .warn (f"Failed to load the CUDA extension: { e } , check if ninja is available." )
60
95
warnings .warn ("Setting boft_n_butterfly_factor to 1 to speed up the finetuning process." )
0 commit comments