diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml
index a19e7511d..ce4a55aaa 100644
--- a/.github/workflows/build_documentation.yml
+++ b/.github/workflows/build_documentation.yml
@@ -13,7 +13,9 @@ jobs:
     with:
       commit_sha: ${{ github.sha }}
       package: bitsandbytes
-      repo_owner: TimDettmers
+      repo_owner: bitsandbytes-foundation
+      # avoid /src suffix leading to wrong links, like bitsandbytes/blob/main/src/bitsandbytes/nn/
+      version_tag_suffix: '' # defaults to '/src'
       custom_container: huggingface/transformers-doc-builder
     secrets:
       hf_token: ${{ secrets.HUGGINGFACE_PUSH }}
diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml
index cc833df5d..4679761c6 100644
--- a/.github/workflows/build_pr_documentation.yml
+++ b/.github/workflows/build_pr_documentation.yml
@@ -9,11 +9,13 @@ concurrency:
 
 jobs:
   build:
-    if: github.repository == 'TimDettmers/bitsandbytes'
+    if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
     uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
     with:
       commit_sha: ${{ github.event.pull_request.head.sha }}
       pr_number: ${{ github.event.number }}
       package: bitsandbytes
-      repo_owner: TimDettmers
+      repo_owner: bitsandbytes-foundation
+      # avoid /src suffix leading to wrong links, like bitsandbytes/blob/main/src/bitsandbytes/nn/
+      version_tag_suffix: '' # defaults to '/src'
       custom_container: huggingface/transformers-doc-builder
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 78bc747c3..671dfee1c 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -63,12 +63,10 @@ jobs:
         os: [ubuntu-latest, windows-latest]
         arch: [x86_64, aarch64]
         cuda_version:
-          ["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.0"]
+          ["11.7.1", "11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.0"]
         exclude:
           - os: windows-latest # This probably requires arm64 Windows agents
             arch: aarch64
-          - os: windows-latest # The Jimver/cuda-toolkit is action used for Windows builds is not updated for 12.4 yet.
-            cuda_version: "12.4.0"
          - os: ubuntu-latest # Temporary. Takes too long, not ready yet.
             arch: aarch64
     runs-on: ${{ matrix.os }} # One day, we could run them on native agents. Azure supports this now but it's planned only for Q3 2023 for hosted agents
@@ -79,7 +77,7 @@ jobs:
         if: startsWith(matrix.os, 'ubuntu')
         uses: docker/setup-qemu-action@v2
       # Windows: We install Cuda on the agent (slow)
-      - uses: Jimver/cuda-toolkit@v0.2.14
+      - uses: Jimver/cuda-toolkit@v0.2.16
        if: startsWith(matrix.os, 'windows')
        id: cuda-toolkit
        with:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c456fa9e5..e446155b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,29 @@
+### 0.43.2
+
+This release is quite significant, as the QLoRA bug fix has big implications for higher `seqlen` and batch sizes.
+
+For each sequence (i.e. a batch size increase of one) we expect memory savings of:
+- 405B: 39GB for seqlen 1024, and 4888GB for 128k
+- 70B: 20.1GB for 1024 and 2516GB for 128k
+
+This is because activations are unnecessary for frozen parameters, yet memory for them was still being erroneously allocated due to the now-fixed bug.
+
+#### Improvements:
+
+- docs: FSDP+QLoRA and CPU install guide (#1211 #1227, thanks @stevhliu)
+- Add CUDA 12.5 and update 12.4 builds (#1284)
+
+#### Bug Fixes
+
+- 4bit getstate and 8bit deepcopy (#1230 #1231, thanks @BenjaminBossan)
+- missing optimizers in `str2optimizer32bit` (#1222, thanks @EtienneDosSantos)
+- CUDA 12.5 build issue (#1273, thanks @HennerM)
+- fix for min_8bit_size functionality in Optimizer base classes (#1286, thanks @Edenzzzz)
+- QLoRA mem bug (#1270, thanks @Ther-nullptr)
+- tests for cpu only platforms (#1259, thanks @galqiwi)
+- restoration of quant_storage for CPU offloading (#1279)
+- optim update error with non-contiguous grads/params (deepspeed) (#1187)
+
 ### 0.43.1
 
 #### Improvements:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3bedefd51..ec48b9d97 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -77,6 +77,13 @@ endif()
 
 if(BUILD_CUDA)
+    # NVCC normally will only work with MSVC up to 1939. VS2022 17.10+ starts using versions 1940+.
+    # Workaround: use --allow-unsupported-compiler
+    # This needs to be added *before* we try to enable the CUDA language so CMake's compiler check passes.
+    if(MSVC AND MSVC_VERSION VERSION_GREATER_EQUAL 1940)
+        string(APPEND CMAKE_CUDA_FLAGS " --allow-unsupported-compiler")
+    endif()
+
     enable_language(CUDA) # This will fail if CUDA is not found
     find_package(CUDAToolkit REQUIRED)
@@ -229,7 +236,6 @@ if(WIN32)
     set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
 endif()
 
-# Weird MSVC hacks
 if(MSVC)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
 endif()
diff --git a/_typos.toml b/_typos.toml
index a04206b8d..e4e7287fb 100644
--- a/_typos.toml
+++ b/_typos.toml
@@ -1,5 +1,10 @@
 [files]
 
+[default]
+extend-ignore-re = [
+    "@Ther-nul", # valid Github user
+]
+
 [default.extend-identifiers]
 
 [type.py.extend-words]
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index c3a2f2402..129ac1536 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -74,4 +74,4 @@
     "optim.optimizer.MockArgs": False,
 }
 
-__version__ = "0.43.2.dev"
+__version__ = "0.43.3.dev"
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 8e296a8ee..59e26ad09 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -524,7 +524,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (A, B)
+            ctx.tensors = (None, B)
         else:
             ctx.tensors = (None, None)
 
@@ -537,7 +537,7 @@ def backward(ctx, grad_output):
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
 
         req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
-        A, B = ctx.tensors
+        _, B = ctx.tensors
 
         grad_A, grad_B, grad_bias = None, None, None
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 2041589b3..6cf64df28 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -27,6 +27,35 @@ def prod(iterable):
 
 if lib and lib.compiled_with_cuda:
     """C FUNCTIONS FOR OPTIMIZERS"""
+    str2optimizer32bit = {
+        "adam": (
+            lib.cadam32bit_grad_fp32,
+            lib.cadam32bit_grad_fp16,
+            lib.cadam32bit_grad_bf16,
+        ),
+        "momentum": (
+            lib.cmomentum32bit_grad_32,
+            lib.cmomentum32bit_grad_16,
+        ),
+        "rmsprop": (
+            lib.crmsprop32bit_grad_32,
+            lib.crmsprop32bit_grad_16,
+        ),
+        "lion": (
+            lib.clion32bit_grad_fp32,
+            lib.clion32bit_grad_fp16,
+            lib.clion32bit_grad_bf16,
+        ),
+        "adagrad": (
+            lib.cadagrad32bit_grad_32,
+            lib.cadagrad32bit_grad_16,
+        ),
+        "lamb": (
+            lib.cadam32bit_grad_fp32,
+            lib.cadam32bit_grad_fp16,
+        ),
+    }
+
     str2optimizer8bit = {
         "adam": (
             lib.cadam_static_8bit_grad_32,
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 7ab070785..c92b25e2c 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -240,7 +240,7 @@ def __new__(
         return self
 
     def __getstate__(self):
-        state = self.__dict__
+        state = self.__dict__.copy()
         state["data"] = self.data
         state["requires_grad"] = self.requires_grad
         return state
@@ -286,6 +286,9 @@ def from_prequantized(
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
         self.bnb_quantized = True
+
+        self.quant_storage = data.dtype
+
         return self
 
     def _quantize(self, device):
@@ -340,6 +343,7 @@ def to(self, *args, **kwargs):
                 blocksize=self.blocksize,
                 compress_statistics=self.compress_statistics,
                 quant_type=self.quant_type,
+                quant_storage=self.quant_storage,
             )
 
             return new_param
@@ -457,7 +461,7 @@ def forward(self, x: torch.Tensor):
                 # since we registered the module, we can recover the state here
                 assert self.weight.shape[1] == 1
                 if not isinstance(self.weight, Params4bit):
-                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
                 self.weight.quant_state = self.quant_state
             else:
                 print(
@@ -567,13 +571,12 @@ def __new__(
         CB=None,
         SCB=None,
     ):
-        cls.has_fp16_weights = has_fp16_weights
-        cls.CB = None
-        cls.SCB = None
         if data is None:
             data = torch.empty(0)
         obj = torch.Tensor._make_subclass(cls, data, requires_grad)
-        obj.CB, obj.SCB = cls.CB, cls.SCB
+        obj.CB = CB
+        obj.SCB = SCB
+        obj.has_fp16_weights = has_fp16_weights
         return obj
 
     def cuda(self, device):
@@ -592,6 +595,18 @@ def cuda(self, device):
 
         return self
 
+    def __deepcopy__(self, memo):
+        # adjust this if new arguments are added to the constructor
+        new_instance = type(self).__new__(
+            type(self),
+            data=copy.deepcopy(self.data, memo),
+            requires_grad=self.requires_grad,
+            has_fp16_weights=self.has_fp16_weights,
+            CB=copy.deepcopy(self.CB, memo),
+            SCB=copy.deepcopy(self.SCB, memo),
+        )
+        return new_instance
+
     def cpu(self):
         # we store the 8-bit rows-major weight
         B = self.data.contiguous().bfloat16().cpu()
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index f1e60e5e7..e9c857d49 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -437,7 +437,7 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32:
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
             state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
@@ -474,6 +474,10 @@ def init_state(self, group, p, gindex, pindex):
 
     @torch.no_grad()
     def update_step(self, group, p, gindex, pindex):
+        # avoid update error from non-contiguous memory layout
+        p.data = p.data.contiguous()
+        p.grad = p.grad.contiguous()
+
         state = self.state[p]
         grad = p.grad
 
@@ -656,7 +660,7 @@ def init_state(self, group, p, gindex, pindex):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32:
             state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
@@ -685,6 +689,10 @@ def init_state(self, group, p, gindex, pindex):
 
     @torch.no_grad()
     def update_step(self, group, p, gindex, pindex):
+        # avoid update error from non-contiguous memory layout
+        p.data = p.data.contiguous()
+        p.grad = p.grad.contiguous()
+
         state = self.state[p]
         grad = p.grad
 
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index f4673359b..e4d459961 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -12,8 +12,6 @@
 #include
 #include
 #include
-#include
-#include
 
 #include
 
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index da9df6af0..8b9a4f449 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -19,10 +19,6 @@
 #include
 #include
 
-#include
-#include
-
-
 #define CUDA_CHECK_RETURN(value) {                      \
     cudaError_t _m_cudaStat = value;                    \
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index c07ef29f6..6cac47fa9 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -1,8 +1,10 @@
 # Installation
 
-bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.3**.
+## CUDA
 
-The latest version of bitsandbytes (v0.43.0) builds on:
+bitsandbytes is only supported on CUDA GPUs for CUDA versions **11.0 - 12.5**. There is a multi-backend effort underway, currently in alpha release; see further down in this document.
+
+The latest version of bitsandbytes builds on:
 
 | OS | CUDA | Compiler |
 |---|---|---|
@@ -29,7 +31,7 @@ To install from PyPI.
 pip install bitsandbytes
 ```
 
-## Compile from source
+### Compile from source
 
 For Linux and Windows systems, you can compile bitsandbytes from source. Installing from source allows for more build options with different CMake configurations.
 
@@ -59,7 +61,7 @@ git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/
 pip install -r requirements-dev.txt
 cmake -DCOMPUTE_BACKEND=cuda -S .
 make
-pip install .
+pip install -e .   # `-e` for "editable" install, when developing BNB (otherwise leave that out)
 ```
 
 > [!TIP]
@@ -83,58 +85,30 @@ git clone https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/
 pip install -r requirements-dev.txt
 cmake -DCOMPUTE_BACKEND=cuda -S .
 cmake --build . --config Release
-python -m build --wheel
+pip install -e .   # `-e` for "editable" install, when developing BNB (otherwise leave that out)
 ```
 
 Big thanks to [wkpark](https://github.com/wkpark), [Jamezo97](https://github.com/Jamezo97), [rickardp](https://github.com/rickardp), [akx](https://github.com/akx) for their amazing contributions to make bitsandbytes compatible with Windows.
 
-
-
-## Multi-backend preview release (+ compilation)
-
-Please follow these steps to install bitsandbytes with device-specific backend support other than CUDA:
-
-
-
-
-For a ROCm specific install:
+
-bitsandbytes is fully supported from ROCm 6.1.
+Windows systems require Visual Studio with C++ support.
 
-**Note:** If you already installed ROCm and PyTorch, skip docker steps below and please check that the torch version matches your ROCm install. To install torch for a specific ROCm version, please refer to step 3 of wheels install in [Installing PyTorch for ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/3rd-party/pytorch-install.html#using-wheels-package) guide.
+To compile from source, you need CMake >= **3.22.1** and Python >= **3.8** installed.
 
 ```bash
-# Create a docker container with latest pytorch. It comes with ROCm and pytorch preinstalled
-docker pull rocm/pytorch:latest
-docker run -it --device=/dev/kfd --device=/dev/dri --group-add video rocm/pytorch:latest
-
-# Clone bitsandbytes repo, ROCm backend is currently enabled on multi-backend-refactor branch
-git clone --depth 1 -b multi-backend-refactor https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/
-
-# Install dependencies
+git clone --branch multi-backend-refactor https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/
 pip install -r requirements-dev.txt
-
-# Compile & install
-cmake -DCOMPUTE_BACKEND=hip -S . # Use -DBNB_ROCM_ARCH="gfx90a;gfx942" to target specific gpu arch
-make
+cmake -DCOMPUTE_BACKEND=cpu -S .
+cmake --build . --config Release
 pip install .
 ```
-
-
-
-WIP
-
-
-
-
-WIP
-
-## PyTorch CUDA versions
+### PyTorch CUDA versions
 
 Some bitsandbytes features may need a newer CUDA version than the one currently supported by PyTorch binaries from Conda and pip. In this case, you should follow these instructions to load a precompiled bitsandbytes binary.
@@ -148,7 +122,7 @@ Then locally install the CUDA version you need with this script from bitsandbyte
 ```bash
 wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh
 # Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
-# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124}
+# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125}
 # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
 
 # For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
@@ -170,7 +144,57 @@ For example, to use a local install path:
 ```bash
 export BNB_CUDA_VERSION=117
-export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/tim/local/cuda-11.7
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-11.7
 ```
 
 3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded.
+
+## Multi-backend preview release (+ compilation)
+
+Please follow these steps to install bitsandbytes with device-specific backend support other than CUDA:
+
+
+
+
+### AMD GPU
+
+For a ROCm-specific install:
+
+bitsandbytes is fully supported from ROCm 6.1.
+
+**Note:** If you already installed ROCm and PyTorch, skip the docker steps below and check that the torch version matches your ROCm install. To install torch for a specific ROCm version, please refer to step 3 of the wheels install in the [Installing PyTorch for ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/3rd-party/pytorch-install.html#using-wheels-package) guide.
+
+```bash
+# Create a docker container with latest pytorch. It comes with ROCm and pytorch preinstalled
+docker pull rocm/pytorch:latest
+docker run -it --device=/dev/kfd --device=/dev/dri --group-add video rocm/pytorch:latest
+
+# Clone bitsandbytes repo, ROCm backend is currently enabled on multi-backend-refactor branch
+git clone --depth 1 -b multi-backend-refactor https://github.com/TimDettmers/bitsandbytes.git && cd bitsandbytes/
+
+# Install dependencies
+pip install -r requirements-dev.txt
+
+# Compile & install
+cmake -DCOMPUTE_BACKEND=hip -S . # Use -DBNB_ROCM_ARCH="gfx90a;gfx942" to target specific gpu arch
+make
+pip install -e .   # `-e` for "editable" install, when developing BNB (otherwise leave that out)
+```
+
+
+
+### Intel CPU
+
+> [!TIP]
+> Intel CPU backend only supports building from source; for now, please follow the instructions below.
+
+Like CUDA, you can compile bitsandbytes from source for Linux and Windows systems. Installing from source allows for more build options with different CMake configurations.
+
+
+
+WIP
+
+
+
diff --git a/install_cuda.py b/install_cuda.py
index cf7c8ee71..8267c5e2b 100644
--- a/install_cuda.py
+++ b/install_cuda.py
@@ -17,7 +17,8 @@
     "121": "https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run",
     "122": "https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run",
     "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
-    "124": "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run",
+    "124": "https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run",
+    "125": "https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.42.02_linux.run",
 }
 
diff --git a/install_cuda.sh b/install_cuda.sh
index 2e7fe8ed2..0aa9531fc 100644
--- a/install_cuda.sh
+++ b/install_cuda.sh
@@ -11,7 +11,8 @@ URL120=https://developer.download.nvidia.com/compute/cuda/12.0.1/local_installer
 URL121=https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run
 URL122=https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run
 URL123=https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run
-URL124=https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run
+URL124=https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run
+URL125=https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.42.02_linux.run
 
 CUDA_VERSION=$1
 BASE_PATH=$2
@@ -60,11 +61,14 @@ if [[ -n "$CUDA_VERSION" ]]; then
     elif [[ "$CUDA_VERSION" -eq "124" ]]; then
         URL=$URL124
         FOLDER=cuda-12.4
+    elif [[ "$CUDA_VERSION" -eq "125" ]]; then
+        URL=$URL125
+        FOLDER=cuda-12.5
     else
-        echo "argument error: No cuda version passed as input. Choose among versions 110 to 124"
+        echo "argument error: No cuda version passed as input. Choose among versions 110 to 125"
     fi
 else
-    echo "argument error: No cuda version passed as input. Choose among versions 92 to 123"
+    echo "argument error: No cuda version passed as input. Choose among versions 110 to 125"
 fi
 
 FILE=$(basename $URL)
diff --git a/requirements-ci.txt b/requirements-ci.txt
index 0e9dd2407..182e1023e 100644
--- a/requirements-ci.txt
+++ b/requirements-ci.txt
@@ -1,6 +1,6 @@
 # Requirements used for GitHub actions
-pytest==8.2.1
+pytest==8.3.1
 einops==0.8.0
-lion-pytorch==0.1.4
+lion-pytorch==0.2.2
 scipy==1.10.1; python_version < "3.9"
-scipy==1.13.1; python_version >= "3.9"
+scipy==1.14.0; python_version >= "3.9"
diff --git a/requirements-dev.txt b/requirements-dev.txt
index de7adce94..41211880c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,9 +1,9 @@
 # Requirements used for local development
 setuptools>=63
-pytest~=8.2.1
+pytest~=8.3.1
 einops~=0.8.0
 wheel~=0.43.0
-lion-pytorch~=0.1.4
-scipy~=1.13.1
+lion-pytorch~=0.2.2
+scipy~=1.14.0
 pandas~=2.2.2
-matplotlib~=3.9.0
+matplotlib~=3.9.1
diff --git a/setup.py b/setup.py
index f8d6a92a1..18de0fe5b 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ def has_ext_modules(self):
 
 setup(
     name="bitsandbytes",
-    version="0.43.2.dev",
+    version="0.43.3.dev",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="k-bit optimizers and matrix multiplication routines.",
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index bbbd05335..2f094be27 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -186,19 +186,30 @@ def test_copy_param():
 def test_deepcopy_param():
     tensor = torch.tensor([1.0, 2.0, 3.0, 4.0])
     param = bnb.nn.Params4bit(data=tensor, requires_grad=False).cuda(0)
+    dict_keys_before = set(param.__dict__.keys())
     copy_param = copy.deepcopy(param)
+    dict_keys_after = set(param.__dict__.keys())
+    dict_keys_copy = set(copy_param.__dict__.keys())
+
     assert param.quant_state is not copy_param.quant_state
     assert param.data.data_ptr() != copy_param.data.data_ptr()
 
+    # there was a bug where deepcopy would modify the original object
+    assert dict_keys_before == dict_keys_after
+    assert dict_keys_before == dict_keys_copy
+
 
 def test_params4bit_real_serialization():
     original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
     original_param = bnb.nn.Params4bit(data=original_tensor, quant_type="fp4")
+    dict_keys_before = set(original_param.__dict__.keys())
 
     original_param.cuda(0)  # move to CUDA to trigger quantization
 
     serialized_param = pickle.dumps(original_param)
     deserialized_param = pickle.loads(serialized_param)
+    dict_keys_after = set(original_param.__dict__.keys())
+    dict_keys_deserialized = set(deserialized_param.__dict__.keys())
 
     assert torch.equal(original_param.data, deserialized_param.data)
     assert original_param.requires_grad == deserialized_param.requires_grad == False
@@ -206,3 +217,7 @@ def test_params4bit_real_serialization():
     assert original_param.blocksize == deserialized_param.blocksize
     assert original_param.compress_statistics == deserialized_param.compress_statistics
     assert original_param.quant_state == deserialized_param.quant_state
+
+    # there was a bug where deepcopy would modify the original object
+    assert dict_keys_before == dict_keys_after
+    assert dict_keys_before == dict_keys_deserialized
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 2a4bd02e2..c4409cc2e 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -1,5 +1,7 @@
 from contextlib import nullcontext
+import copy
 import os
+import pickle
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -181,3 +183,59 @@ def test_linear_serialization(
     assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
     assert torch.allclose(fx_first, fx_third, atol=1e-5)
     assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)
+
+
+@pytest.fixture
+def linear8bit(requires_cuda):
+    linear = torch.nn.Linear(32, 96)
+    linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=False,
+        threshold=6.0,
+    )
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(),
+        requires_grad=False,
+        has_fp16_weights=False,
+    )
+    linear_custom.bias = linear.bias
+    linear_custom = linear_custom.cuda()
+    return linear_custom
+
+
+def test_linear8bit_copy_param(linear8bit):
+    shallow_copy = copy.copy(linear8bit)
+    assert linear8bit.weight is shallow_copy.weight
+    assert linear8bit.bias is shallow_copy.bias
+    assert linear8bit.weight.data.data_ptr() == shallow_copy.weight.data.data_ptr()
+
+
+def test_linear8bit_deepcopy_param(linear8bit):
+    deep_copy = copy.deepcopy(linear8bit)
+    assert linear8bit.weight is not deep_copy.weight
+    assert linear8bit.bias is not deep_copy.bias
+    assert linear8bit.weight.data.data_ptr() != deep_copy.weight.data.data_ptr()
+    assert torch.allclose(linear8bit.weight.data, deep_copy.weight.data)
+    assert linear8bit.state == deep_copy.state
+
+    # check for a bug where SCB and CB were not copied
+    assert deep_copy.weight.SCB is not None
+    assert (linear8bit.weight.SCB == deep_copy.weight.SCB).all()
+    assert deep_copy.weight.CB is not None
+    assert (linear8bit.weight.CB == deep_copy.weight.CB).all()
+
+
+def test_linear8bit_serialization(linear8bit):
+    serialized = pickle.dumps(linear8bit)
+    deserialized = pickle.loads(serialized)
+    assert linear8bit.weight.data.data_ptr() != deserialized.weight.data.data_ptr()
+    assert torch.allclose(linear8bit.weight.data, deserialized.weight.data)
+    assert linear8bit.bias.data.data_ptr() != deserialized.bias.data.data_ptr()
+    assert torch.allclose(linear8bit.bias.data, deserialized.bias.data)
+    assert linear8bit.state == deserialized.state
+
+    # check for a bug where SCB and CB were not copied
+    assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
+    assert (linear8bit.weight.CB == deserialized.weight.CB).all()
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 8235b600c..1947ba52d 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -620,7 +620,7 @@ def test_fp8linear():
     assert bgraderr < 0.00002
 
 
-def test_4bit_warnings():
+def test_4bit_warnings(requires_cuda):
     dim1 = 64
 
     with pytest.warns(UserWarning, match=r"inference or training"):