diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE new file mode 100644 index 0000000000000..a46e2cdcadf7d --- /dev/null +++ b/csrc/punica/LICENSE @@ -0,0 +1,217 @@ +Contains code from https://github.com/punica-ai/punica + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ + +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache-2.0 +* third_party/nvbench (with LLVM exception) +* third_party/flashinfer + +BSD-3-Clause: +* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu new file mode 100644 index 0000000000000..bc86416701f13 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -0,0 +1,21 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h new file mode 100644 index 0000000000000..3fd56b685be13 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_config.h @@ -0,0 +1,53 @@ +#pragma once + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale); + +// clang-format off + +#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 128) \ + f(in_T, out_T, W_T, narrow, 256) \ + f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 1024) \ + f(in_T, out_T, W_T, narrow, 1280) \ + f(in_T, out_T, W_T, narrow, 1728) \ + f(in_T, out_T, W_T, narrow, 1792) \ + f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2560) \ + f(in_T, out_T, W_T, narrow, 2752) \ + f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3456) \ + f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 5120) \ + f(in_T, out_T, W_T, 
narrow, 5504) \ + f(in_T, out_T, W_T, narrow, 6912) \ + f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 9216) \ + f(in_T, out_T, W_T, narrow, 10240) \ + f(in_T, out_T, W_T, narrow, 11008) \ + f(in_T, out_T, W_T, narrow, 12288) \ + f(in_T, out_T, W_T, narrow, 13824) \ + f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 16384) \ + f(in_T, out_T, W_T, narrow, 20480) \ + f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 32000) \ + f(in_T, out_T, W_T, narrow, 32256) \ + f(in_T, out_T, W_T, narrow, 36864) \ + f(in_T, out_T, W_T, narrow, 49152) \ + +#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + +// clang-format on diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh new file mode 100644 index 0000000000000..995de26e8bada --- /dev/null +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -0,0 +1,294 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "vec_dtypes.cuh" + +namespace cg = cooperative_groups; + +// nthrs = (32, 4) +template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t j = blockIdx.x; + constexpr size_t num_pipeline_stages = 2; + constexpr size_t tile_size = tx * ty * vec_size; + __shared__ W_T W_shared[num_pipeline_stages * tile_size]; + __shared__ in_T X_shared[num_pipeline_stages * tile_size]; + __shared__ float y_warpwise[ty]; + + size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + auto pipe = cuda::make_pipeline(); + + // pipeline load W/X and compute WX; + pipe.producer_acquire(); + cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + pipe.producer_commit(); + size_t copy_idx, compute_idx; + float y = 0.f; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; + ++tile_idx) { + copy_idx = tile_idx % num_pipeline_stages; + // pipeline stage: async copy W fragment + pipe.producer_acquire(); + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + } + 
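+      // Two-stage software pipeline: while cuda::memcpy_async fills the
+      // shared-memory buffers of stage `copy_idx`, the previously committed
+      // stage `compute_idx` is consumed below, overlapping global-memory
+      // traffic with the per-thread dot products and warp reductions.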
pipe.producer_commit(); + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // pipeline stage: compute WX + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = sum; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + } + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // final pipeline stage + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = + ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) + ? sum + : 0.f; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + + // write Y; + if (block.thread_rank() == 0) { + Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); + } +} + +// nthrs = (2, 16, 4) +template +__global__ void +bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t tile_idx = blockIdx.x; + + // load X; + vec_t x_vec; + x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); + + // load W; + vec_t w_vec; + w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + + block.thread_rank() * vec_size); + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } + + cg::thread_block_tile g = cg::tiled_partition(block); +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += g.shfl_down(sum, offset); + } + sum = g.shfl(sum, 0); + + if (threadIdx.x == 0) { + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y] += static_cast(sum); + } +} + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + constexpr size_t vec_size = 8; + constexpr int tz = 4; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if constexpr (feat_in < feat_out) { + static_assert(feat_in % vec_size == 0); + constexpr int tx = feat_in / vec_size; + + static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || + (16 % 
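+    // This static_assert pins down a valid launch geometry for
+    // bgmv_expand_kernel: tx = feat_in / vec_size lanes cooperate on one
+    // output feature via a warp-level shfl_down reduction, ty = {32,16,8} / tx
+    // and tz = 4 set how many output features each block produces, and the
+    // grid covers feat_out / (ty * tz) feature tiles by batch_size.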
tx == 0 && feat_out % (16 / tx * tz) == 0) || + (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); + + if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { + constexpr int ty = 32 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { + constexpr int ty = 16 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else { + constexpr int ty = 8 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } else { + static_assert(feat_in % (vec_size * 32) == 0 || + feat_in % (vec_size * 16) == 0 || + feat_in % (vec_size * 8) == 0); + + if constexpr (feat_in % (vec_size * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { + constexpr int tx = 16; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } +} + +#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ + template void bgmv_kernel( \ + out_T * __restrict__ Y, const in_T *__restrict__ X, \ + const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ + int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ + int64_t num_layers, int64_t layer_idx, float scale); + +#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ + INST_BGMV(narrow, wide, in_T, out_T, W_T) \ + INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh new file mode 100644 index 0000000000000..cf00d869cf635 --- /dev/null +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -0,0 +1,1324 @@ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#include +#include +#ifdef FLASHINFER_USE_FP8 +#include +#endif +#include + +#include + +#define FLASHINFER_INLINE \ + inline __attribute__((always_inline)) __device__ __host__ + +template +struct vec_t { + FLASHINFER_INLINE float_t &operator[](size_t i); + FLASHINFER_INLINE const float_t &operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t *ptr); + FLASHINFER_INLINE void store(float_t *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src); + template + FLASHINFER_INLINE void cast_load(const T *ptr); + template + FLASHINFER_INLINE void cast_store(T *ptr) const; + FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); +}; + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = tgt_float_t(src[i]); + } +} + +template +FLASHINFER_INLINE void 
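+// cast_load_impl (below) and cast_store_impl let callers read or write a
+// vector stored in one dtype while computing in another: when the source and
+// target element types match they forward to plain load()/store(), otherwise
+// they go through a temporary vec_t of the other type and convert
+// element-wise via cast_from.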
cast_load_impl(const src_float_t *src_ptr, + vec_t &dst) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl(const vec_t &src, + tgt_float_t *dst_ptr) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +#ifdef FLASHINFER_USE_FP8 +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> +struct vec_t<__nv_fp8_e4m3, 1> { + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( + __nv_fp8_e4m3 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> +struct vec_t<__nv_fp8_e4m3, 2> { + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x2_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x2_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> +struct vec_t<__nv_fp8_e4m3, 4> { + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return 
((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x4_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x4_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> +struct vec_t<__nv_fp8_e4m3, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { + ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( + __nv_fp8_e4m3 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 16 or more +template +struct vec_t<__nv_fp8_e4m3, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + 
(__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> +struct vec_t<__nv_fp8_e5m2, 1> { + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( + __nv_fp8_e5m2 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *dst = *src; +} + +// __nv_fp8_e5m2 x 2 +template <> +struct vec_t<__nv_fp8_e5m2, 2> { + __nv_fp8x2_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) 
{ + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x2_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x2_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 4 + +template <> +struct vec_t<__nv_fp8_e5m2, 4> { + __nv_fp8x4_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x4_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x4_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 8 + +template <> +struct vec_t<__nv_fp8_e5m2, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { + ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&data.y))->__x = 
(__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( + __nv_fp8_e5m2 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 16 or more + +template +struct vec_t<__nv_fp8_e5m2, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; +#endif + +/******************* vec_t *******************/ + +// half x 1 +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const half 
*ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *dst = *src; +} + +// half x 2 +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + data = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((half2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((half2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((half2 *)dst) = *((half2 *)src); +} + +// half x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2 *)(&data.x) = make_half2(val, val); + *(half2 *)(&data.y) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// half x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)data)[i]; + } + FLASHINFER_INLINE void fill(half val) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + *(half2 *)(&(data[i].x)) = make_half2(val, val); + *(half2 *)(&(data[i].y)) = make_half2(val, val); + *(half2 *)(&(data[i].z)) = make_half2(val, val); + *(half2 *)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(half *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + 
FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// nv_bfloat16 x 1 +template <> +struct vec_t { + nv_bfloat16 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *dst = *src; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t { + nv_bfloat162 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((nv_bfloat162 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((nv_bfloat162 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); +} + +// nv_bfloat16 x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 
*dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// nv_bfloat16 x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// float x 1 + +template <> +struct vec_t { + float data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t { + float2 data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void 
cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { + data = make_float2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { + data = *((float2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { + *((float2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *((float2 *)dst) = *((float2 *)src); +} + +// float x 4 or more +template +struct vec_t { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(data))[i]; + } + FLASHINFER_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)dst)[i] = ((float4 *)src)[i]; + } + } +}; + +/******************* vec_t type cast *******************/ + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = half(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = + __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = nv_bfloat16(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162 *)(&dst.data))[i] = + __float22bfloat162_rn(((float2 *)(&src.data))[i]); + } + } +} + +#ifdef FLASHINFER_USE_FP8 + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); + } 
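+  // fp8 conversions lean on the packed CUDA types (__nv_fp8x2_e4m3,
+  // __nv_fp8x4_e4m3, and their e5m2 counterparts further down): two or four
+  // fp8 values are widened to float2 / float4 in one conversion rather than
+  // element by element.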
else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = + __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = + __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +#endif // FLASHINFER_USE_FP8 + +#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc new file mode 100644 index 0000000000000..4ad46e5e1f726 --- /dev/null 
+++ b/csrc/punica/punica_ops.cc @@ -0,0 +1,563 @@ +#include +#include +#include + +#include + +#include "bgmv/bgmv_config.h" + +namespace { + +//====== utils ====== + +inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, + const char *a_name, const char *b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", + a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, + ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) \ + TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) \ + TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +//====== bgmv ====== + +template +inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, + const int64_t *lora_indices, + uint16_t in_features, uint16_t out_features, + int64_t y_offset, int64_t full_y_size, + int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + switch (pack_u16(in_features, out_features)) { +#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ + case pack_u16(feat_in, feat_out): \ + bgmv_kernel(Y, X, W, lora_indices, y_offset, \ + full_y_size, batch_size, num_layers, \ + layer_idx, scale); \ + break; +#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) + + FOR_BGMV_WIDE_NARROW(CASE, _, _, _) +#undef CASE +#undef CASE_ONESIDE + default: + return false; + } + + return true; +} + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t h_in = x.size(1); + int64_t h_out = y.size(1); + int64_t num_layers = w.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, 
layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, 
scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t num_layers = w.size(1); + int64_t full_y_size = y.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + 
layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = 
launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +} // namespace + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/setup.py b/setup.py index 2b040e88f0aa4..2e11119043277 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,16 @@ +import contextlib import io import os import re import subprocess -from typing import List, Set import warnings +from pathlib import Path +from typing import List, Set from packaging.version import parse, Version import setuptools import torch +import torch.utils.cpp_extension as torch_cpp_ext from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME ROOT_DIR = os.path.dirname(__file__) @@ -31,6 +34,11 @@ "Cannot find CUDA_HOME. CUDA must be available to build the package.") +def glob(pattern: str): + root = Path(__name__).parent + return [str(p) for p in root.glob(pattern)] + + def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -129,19 +137,59 @@ def get_torch_arch_list() -> Set[str]: raise RuntimeError( "CUDA 11.8 or higher is required for compute capability 9.0.") +# Use NVCC threads to parallelize the build. +if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] + +NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() + # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: num = capability[0] + capability[2] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] - -# Use NVCC threads to parallelize the build. 
-if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] + if int(capability[0]) >= 8: + NVCC_FLAGS_PUNICA += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS_PUNICA += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] + +# changes for punica kernels +NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS +REMOVE_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', +] +for flag in REMOVE_NVCC_FLAGS: + with contextlib.suppress(ValueError): + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) ext_modules = [] + +install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1"))) +device_count = torch.cuda.device_count() +for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 8: + install_punica = False + break +if install_punica: + ext_modules.append( + CUDAExtension( + name="vllm._punica_C", + sources=["csrc/punica/punica_ops.cc"] + + glob("csrc/punica/bgmv/*.cu"), + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS_PUNICA, + }, + )) + vllm_extension = CUDAExtension( name="vllm._C", sources=[ diff --git a/tests/lora/__init__.py b/tests/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 0000000000000..263a2bc9d8156 --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,139 @@ +import gc +import tempfile +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +import pytest +import ray +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download + +import vllm +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel, initialize_model_parallel) + + +def cleanup(): + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() + + +@pytest.fixture(autouse=True) +def cleanup_fixture(): + yield + cleanup() + + +@pytest.fixture +def dist_init(): + if not torch.distributed.is_initialized(): + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(1, 1) + yield + cleanup() + + +@pytest.fixture +def dist_init_torch_only(): + if torch.distributed.is_initialized(): + return + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + + +@pytest.fixture +def dummy_model() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler 
+ ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture +def dummy_model_gate_up() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture(scope="session") +def sql_lora_files(): + return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +@pytest.fixture +def llama_2_7b_engine_extra_embeddings() -> nn.Module: + cleanup() + get_model_old = get_model + + def get_model_patched(model_config, lora_config=None): + return get_model_old(model_config, LoRAConfig(max_lora_rank=8)) + + with patch("vllm.worker.worker.get_model", get_model_patched): + engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) + yield engine.llm_engine + del engine + cleanup() + + +@pytest.fixture +def llama_2_7b_model_extra_embeddings( + llama_2_7b_engine_extra_embeddings) -> nn.Module: + yield llama_2_7b_engine_extra_embeddings.workers[0].model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 0000000000000..fa6a18e8d93d2 --- /dev/null +++ b/tests/lora/test_layers.py @@ -0,0 +1,697 @@ +import pytest +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Dict, Tuple + +import torch +import torch.nn.functional as F + +from vllm.lora.layers import ( + LoRAColumnParallelLinear, + LoRAMergedColumnParallelLinear2Slice, + LoRAQKVParallelLinear, + LoRAVocabParallelEmbedding, + LoRARowParallelLinear, + LoRASampler, + LoRAMapping, + LoRALayer, +) +from vllm.lora.models import LoRA, convert_mapping +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.utils import set_random_seed + +from .utils import DummyLoRAManager + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. 
" + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: LoRALayer, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRA], Dict[int, List[LoRA]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRA] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. Only useful when + # repeats > 1. + sublora_dict: Dict[int, List[LoRA]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager().init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = LoRA.pack(subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. 
+ """ + + low, high = input_range + + inputs, index_mapping, prompt_mapping = [], [], [] + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device="cuda")) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device="cuda") * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[512:, :] = 0 + lora_embedding = LoRAVocabParallelEmbedding(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info) + + lora_result = lora_embedding(torch.cat(inputs)) + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = embedding(input_) + after_a = F.embedding( + input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(inputs)) + expected_result = embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.") +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding_data = torch.rand_like(embedding.weight.data) + embedding.weight.data = embedding_data + 
embedding.weight.data[512:, :] = 0 + expanded_embedding = VocabParallelEmbedding( + 512 + lora_config.lora_extra_vocab_size * max_loras, + 256, + org_num_embeddings=512) + expanded_embedding.weight.data[:512, :] = embedding_data + # We need to deepcopy the embedding as it will be modifed + # in place + lora_embedding = LoRAVocabParallelEmbedding( + deepcopy(expanded_embedding)) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return expanded_embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + expanded_embedding, lora_embedding = create_random_embedding_layer() + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=torch.zeros( + (256, 512 + lora_config.lora_extra_vocab_size)), + generate_embeddings_tensor=256, + ) + + # All embeddings tensors have the same shape. + embeddings_tensors = [ + lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) + ] + embeddings_tensor_len = embeddings_tensors[0].shape[0] + + # Add empty embeddings_tensors for unoccupied lora slots. + for _ in range(max_loras - len(embeddings_tensors)): + embeddings_tensors.append( + torch.zeros(embeddings_tensors[0].shape, device="cuda")) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + # Force some of the inputs to be in the extended embeddings range + # to guarantee that their behavior is tested. + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + embedding_id = lora_id - 1 + input_[-1] = 512 + (embedding_id * embeddings_tensor_len) + original_input_[-1] = 512 + input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = 512 + embeddings_tensor_len - 1 + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + expanded_embedding.weight[512:512 + + (embeddings_tensor_len * + max_loras)] = torch.cat(embeddings_tensors) + + lora_result = lora_embedding(torch.cat(original_inputs)) + + expected_results = [] + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + lora = lora_dict[lora_id] + result = expanded_embedding(input_) + after_a = F.embedding( + original_input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(original_inputs)) + expected_result = expanded_embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + 
assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_lm_head_sampler(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_sampler_layer(): + linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, + 1024, 32000) + linear.weight.data = torch.rand_like(linear.weight.data) + linear.weight.data[:, 32000:] = 0 + sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) + lora_sampler = LoRASampler(sampler, 1024, linear.weight.dtype, + linear.weight.device) + lora_sampler.create_lora_weights(max_loras, lora_config) + + return linear, sampler, lora_sampler + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, sampler, lora_sampler = create_random_sampler_layer() + + # NOTE: all the generated loras share the same embeddings tensor. + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_sampler, + layer_weights=linear.weight, + generate_embeddings_tensor=1024, + ) + embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor + embeddings_tensor_len = embeddings_tensor.shape[0] + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, # * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + input_ = torch.rand(20, 1024, device="cuda") + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 32000, + lora_config.lora_extra_vocab_size, + ) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) + + original_weight = linear.weight.clone() + + linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + embeddings_tensor_len] = embeddings_tensor + + sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = sampler._get_logits(hidden_states=input_, + embedding=linear.weight, + embedding_bias=None) + result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + sampler.org_vocab_size = 32000 + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_sampler.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 32000, + lora_config.lora_extra_vocab_size) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + 
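The reference computation repeated throughout these layer tests is the plain LoRA identity: the wrapped layer is expected to return the base layer's output plus the low-rank correction `x @ lora_a @ lora_b * scaling`. The following standalone sketch illustrates that reference outside of the test harness; the function name `lora_reference`, the shapes, and the `lora_alpha / rank` scaling convention are assumptions made for illustration (the tests themselves take `lora_a`, `lora_b`, and `scaling` from `DummyLoRAManager`), not code from `vllm.lora`.

# Standalone sketch of the reference LoRA computation the tests compare against.
# Names, shapes, and the lora_alpha / rank scaling convention are illustrative
# assumptions.
import torch


def lora_reference(x: torch.Tensor, base_weight: torch.Tensor,
                   lora_a: torch.Tensor, lora_b: torch.Tensor,
                   scaling: float) -> torch.Tensor:
    """Base layer output plus the low-rank LoRA correction."""
    base = x @ base_weight.T          # (batch, out): output without any LoRA applied
    update = (x @ lora_a) @ lora_b    # (batch, in) @ (in, r) @ (r, out): rank-r bottleneck
    return base + update * scaling


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, in_features, out_features, rank = 4, 64, 32, 8
    x = torch.rand(batch, in_features)
    base_weight = torch.rand(out_features, in_features)
    lora_a = torch.rand(in_features, rank)
    lora_b = torch.rand(rank, out_features)
    scaling = 16 / rank               # lora_alpha / rank under the assumed convention
    out = lora_reference(x, base_weight, lora_a, lora_b, scaling)
    assert out.shape == (batch, out_features)

The tests above and below build `expected_result` in exactly this shape (base output plus per-LoRA update, concatenated over the batch) and then compare it to the LoRA layer's output with `torch.allclose` under the per-dtype tolerances in `TOLERANCES`.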
+@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +def test_linear_parallel(dist_init, num_loras, orientation) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_linear_parallel_layer(): + if orientation == "row": + linear = RowParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRARowParallelLinear(linear) + else: + linear = ColumnParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAColumnParallelLinear(linear) + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_parallel_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("repeats", [2, 3]) +def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_column_parallel_packed_layer(): + if repeats == 2: + linear = MergedColumnParallelLinear(4096, [4096] * repeats, + bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAMergedColumnParallelLinear2Slice(linear) + else: + linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAQKVParallelLinear(linear) + + @dataclass + class FakeConfig: + hidden_size = 4096 + num_key_value_heads = 32 + 
num_attention_heads = 32 + + lora_linear.create_lora_weights(max_loras, + lora_config, + model_config=FakeConfig()) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + + linear, lora_linear = create_column_parallel_packed_layer() + + lora_dict, sublora_dict = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + repeats=repeats, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * ( + i + 1 + )] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py new file mode 100644 index 0000000000000..756fc55246092 --- /dev/null +++ b/tests/lora/test_llama.py @@ -0,0 +1,141 @@ +import pytest +import ray +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + + +def do_sample(llm, lora_path: str, lora_id: int): + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? 
[/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]" + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_llama_lora(sql_lora_files, tp_size): + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=tp_size, + worker_use_ray=True) + + expected_no_lora_output = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? 
[/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", + ] + expected_lora_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " + ] + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output + + print("removing lora") + + +def test_llama_tensor_parallel_equality(sql_lora_files): + if torch.cuda.device_count() < 4: + pytest.skip(f"Not enough GPUs for tensor parallelism {4}") + + llm_tp1 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=1, + worker_use_ray=True) + output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) + + del llm_tp1 + ray.shutdown() + + llm_tp2 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=2, + worker_use_ray=True) + output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) + + del llm_tp2 + ray.shutdown() + + assert output_tp1 == output_tp2 + + llm_tp4 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=4, + worker_use_ray=True) + output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) + + del llm_tp4 + 
ray.shutdown() + + assert output_tp1 == output_tp4 + + +def test_llama_lora_warmup(sql_lora_files): + """Test that the LLM initialization works with a warmup LORA path and is more conservative""" + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_lora(): + llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) + num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_lora_warmup + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_no_lora(): + llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) + num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_no_lora_warmup + + num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) + num_gpu_blocks_no_lora_warmup = ray.get( + get_num_gpu_blocks_no_lora.remote()) + assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( + "The warmup with lora should be more" + " conservative than without lora, therefore the number of memory blocks for the KV cache should be " + "less when using lora than when not using lora") diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py new file mode 100644 index 0000000000000..b86f7a480e749 --- /dev/null +++ b/tests/lora/test_lora.py @@ -0,0 +1,224 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_2slice, _apply_lora_packed_3slice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.float16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(8, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="cuda", + dtype=dtype) + lora_b_stack = torch.zeros(8, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="cuda", + dtype=dtype) + for i in range(lora_a_stack.shape[0]): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora( + input, lora_a_stack, lora_b_stack, + torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), + output) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="cuda"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, 
dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora_packed_2slice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, m // 2) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_2slice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, m // 2) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = 
lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) + _apply_lora_packed_3slice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, (qkv[0], qkv[1])) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_3slice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, (qkv[0], qkv[1])) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py new file mode 100644 index 0000000000000..de7b245ad4e79 --- /dev/null +++ b/tests/lora/test_lora_manager.py @@ -0,0 +1,473 @@ +import os +from typing import List + +import pytest +import torch +from safetensors.torch import load_file +from torch import nn + +from vllm.config import LoRAConfig +from vllm.lora.layers import (LoRAColumnParallelLinear, LoRARowParallelLinear, + LoRAMergedColumnParallelLinear2Slice) +from vllm.lora.lora import LoRA, PackedLoRA +from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, LoRAMapping) +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, + WorkerLoRAManager) +from vllm.model_executor.layers.linear import RowParallelLinear + + +def test_from_lora_tensors(sql_lora_files): + tensors = load_file( + os.path.join(sql_lora_files, "adapter_model.safetensors")) + new_embeddings = load_file( + os.path.join(sql_lora_files, "new_embeddings.safetensors")) + lora_model = LoRAModel.from_lora_tensors(1, + 8, + 16, + tensors, + "cuda", + embeddings=new_embeddings) + for module_name, lora in lora_model.loras.items(): + assert lora.module_name == module_name + assert lora.rank == 8 + assert lora.lora_alpha == 16 + assert lora.lora_a is not None + assert lora.lora_b is not None + assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] + ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + assert lora.lora_a.shape[1] == 8 + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), None) + if embeddings_module: + assert torch.equal( + lora.embeddings_tensor, + new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( + device=lora.embeddings_tensor.device)) + else: + assert lora.embeddings_tensor is None + + +def create_lora(lora_id: int, model: nn.Module, + sub_modules: List[str]) -> LoRAModel: + loras = {} + for name in sub_modules: + w = model.get_submodule(name).weight + loras[name] = LoRA( + name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0]], device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def create_packed_lora( + lora_id: int, + model: nn.Module, + module_name, + replaced_module_names, + empty_replaced_module_name=None, +) -> LoRAModel: + w = model.get_submodule(module_name).weight + loras = {} + for replaced_module_name in replaced_module_names: + if replaced_module_name == empty_replaced_module_name: + continue + loras[replaced_module_name] = LoRA( + replaced_module_name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0] // 
len(replaced_module_names)], + device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def test_replace_submodules(dist_init, dummy_model): + model = dummy_model + manager = LoRAModelManager(model, + 1, + 1, + 1, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=8, + max_loras=8), + lora_target_modules=["dense1", "layer1.dense2"]) + model = manager.model + + assert isinstance(model.get_submodule("dense1"), LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("layer1.dense1"), + LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("dense2"), RowParallelLinear) + assert isinstance(model.get_submodule("layer1.dense2"), + LoRARowParallelLinear) + + +def test_lora_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_id_to_index) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + with pytest.raises(ValueError): + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_id_to_index[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] is None + assert manager.add_lora(model_lora2) + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] is None + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + + +def test_lora_lru_cache_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_id_to_index) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert 
manager.add_lora(model_lora3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_id_to_index[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 1 + assert manager.add_lora(model_lora2) + assert manager.deactivate_lora(3) + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 1 + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 2 + assert manager.lora_id_to_index[1] == 1 + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 2 + assert manager.lora_id_to_index[1] == 3 + + +def test_lru_lora_model_manager(dist_init, dummy_model): + # This tests just the LRU cache functionality, everything else is + # tested in test_lora_model_manager + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["dense1", "dense2", "lm_head"]) + + assert all(x is None for x in manager.lora_id_to_index) + + # Add up to capacity + assert manager.add_lora(model_lora1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(1) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {1, 2} + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + + # Add over capacity + assert manager.add_lora(model_lora3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(3) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 4 + + # Add 3 again to move it to the top and then add 2 + # should return false since it's in already + assert not manager.add_lora(model_lora3) + assert not manager.activate_lora(3) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {3, 2} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + + # Remove manually + assert manager.remove_lora(3) + assert not manager.remove_lora(3) + + assert set(manager.list_loras()) == {2} + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 2 + + assert manager.add_lora(model_lora3) + assert manager.activate_lora(3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == {4} + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_id_to_index) + + assert not 
manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_id_to_index) + + +def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = LRUCacheWorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 3, 4} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 3 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 6, 7, 8} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 7 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 8 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 6 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.apply_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + # Should remove every LoRA not specified in the request. 
+ lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = WorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 3, 4} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 3 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] is None + assert worker_lora_manager._lora_manager.lora_id_to_index[2] is None + + worker_lora_manager.apply_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {6, 7, 8} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 8 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 6 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 7 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.apply_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_packed_loras(dist_init, dummy_model_gate_up): + model = dummy_model_gate_up + model_lora = create_packed_lora( + 1, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"]) + model_lora1 = create_packed_lora( + 2, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"], + empty_replaced_module_name="gate_proj", + ) + + manager = LoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["gate_up_proj"]) + model = manager.model + + assert isinstance(model.get_submodule("gate_up_proj"), + LoRAMergedColumnParallelLinear2Slice) + assert manager.add_lora(model_lora) + assert manager.add_lora(model_lora1) + + packed_lora = model_lora.get_lora("gate_up_proj") + assert packed_lora and isinstance(packed_lora, PackedLoRA) + + assert torch.allclose(packed_lora.lora_a[0], + 
model_lora.get_lora("gate_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[0], + model_lora.get_lora("gate_proj").lora_b) + assert torch.allclose(packed_lora.lora_a[1], + model_lora.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[1], + model_lora.get_lora("up_proj").lora_b) + + packed_lora1 = model_lora1.get_lora("gate_up_proj") + assert packed_lora1 and isinstance(packed_lora1, PackedLoRA) + + assert packed_lora1.lora_a[0] is None + assert packed_lora1.lora_b[0] is None + assert torch.allclose(packed_lora1.lora_a[1], + model_lora1.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora1.lora_b[1], + model_lora1.get_lora("up_proj").lora_b) diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py new file mode 100644 index 0000000000000..26a7d47933309 --- /dev/null +++ b/tests/lora/test_punica.py @@ -0,0 +1,196 @@ +# Based on code from https://github.com/punica-ai/punica + +import pytest +import torch + +import vllm.lora.punica as punica + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), + torch.float32: (None, None), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def _lora_ref_impl( + y_final: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, +): + y_stage_1 = torch.empty( + (x.size(0), wa_T_all.size(-2)), + dtype=torch.float32, + device=x.device, + ) + bs = x.shape[0] + s = torch.tensor(scale, dtype=torch.float32, device=x.device) + for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): + xi = x[i].unsqueeze(0).to(torch.float32) + wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + + tmp = xi @ wa + y_stage_1[i] = tmp.squeeze(0) + y_final[i] += (tmp @ wb).squeeze(0) * s + return y_final, y_stage_1 + + +H1 = H2 = [ + 128, + 256, + 512, + 1024, + 1280, + 2048, + 2560, + 2752, + 3072, + 3456, + 3584, + 4096, + 5120, + 5504, + 6912, + 7168, + 8192, + 9216, + 10240, + 11008, + 13824, + 14336, + 32000, + 32256, +] +SEED = [0xabcdabcd987] + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness(dtype_str, h1, h2, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all = torch.randn(num_loras, + num_layers, + h2, + r, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) + + y_our = y.clone() + punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, + scale) + + assert_close(y_ref, y_our) + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness_slice(dtype_str, h1, h2, seed): + if h2 % 3 != 0 or h2 // 3 not 
in H1: + pytest.skip("h2 must be divisible by 3 and in supported shapes") + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all_0 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_1 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_2 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all_0 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_1 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_2 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + s = h2 // 3 + + y_ref = y.clone() + _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale) + + y_our = y.clone() + punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale, 0, s) + punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale, s, s) + punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale, s * 2, s) + + assert_close(y_ref[:, :s], y_our[:, :s]) + assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) + assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py new file mode 100644 index 0000000000000..af0fc41f3fa45 --- /dev/null +++ b/tests/lora/test_tokenizer.py @@ -0,0 +1,69 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import MultiLoRATokenizer, get_lora_tokenizer + + +@pytest.mark.asyncio +async def test_transformers_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=None) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +async def test_transformers_tokenizer_lora(sql_lora_files): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=lora_request) + 
assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + lora_request) != tokenizer.get_lora_tokenizer(None) + assert tokenizer.get_lora_tokenizer( + lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py new file mode 100644 index 0000000000000..a874a72d919fa --- /dev/null +++ b/tests/lora/test_utils.py @@ -0,0 +1,172 @@ +from collections import OrderedDict + +from torch import nn + +from vllm.lora.utils import (LRUCache, parse_fine_tuned_lora_name, + replace_submodule) + + +def test_parse_fine_tuned_lora_name(): + fixture = { + ("base_model.model.lm_head.lora_A.weight", "lm_head", True), + ("base_model.model.lm_head.lora_B.weight", "lm_head", False), + ( + "base_model.model.model.embed_tokens.lora_embedding_A", + "model.embed_tokens", + True, + ), + ( + "base_model.model.model.embed_tokens.lora_embedding_B", + "model.embed_tokens", + False, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "model.layers.9.mlp.down_proj", + True, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "model.layers.9.mlp.down_proj", + False, + ), + } + for name, module_name, is_lora_a in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) + + +def test_replace_submodule(): + model = nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ])) + + sigmoid = nn.Sigmoid() + + replace_submodule(model, "act1", sigmoid) + assert dict(model.named_modules())["act1"] == sigmoid + + dense2 = nn.Linear(1, 5) + replace_submodule(model, "seq1.dense2", dense2) + assert dict(model.named_modules())["seq1.dense2"] == dense2 + + +class TestLRUCache(LRUCache): + + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache.get(2) == 2 + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + 
cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py new file mode 100644 index 0000000000000..8c11f6c472ff7 --- /dev/null +++ b/tests/lora/test_worker.py @@ -0,0 +1,56 @@ +import os +import random +import tempfile +from unittest.mock import patch + +from vllm.lora.models import LoRAMapping +from vllm.lora.utils import LoRARequest +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig +from vllm.worker.worker import Worker + + +@patch.dict(os.environ, {"RANK": "0"}) +def test_worker_apply_lora(sql_lora_files): + worker = Worker( + model_config=ModelConfig("meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None), + parallel_config=ParallelConfig(1, 1, False), + scheduler_config=SchedulerConfig(32, 32, 32, 256), + lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, + max_loras=32), + distributed_init_method=f"file://{tempfile.mkstemp()[1]}", + ) + worker.init_model() + worker.load_model() + + worker.apply_loras([], LoRAMapping([], [])) + assert worker.list_loras() == set() + + n_loras = 32 + lora_requests = [ + LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) + ] + + worker.apply_loras(lora_requests, LoRAMapping([], [])) + assert worker.list_loras() == { + lora_request.lora_int_id + for lora_request in lora_requests + } + + for i in range(32): + random.seed(i) + iter_lora_requests = random.choices(lora_requests, + k=random.randint(1, n_loras)) + random.shuffle(iter_lora_requests) + iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] + worker.apply_loras(iter_lora_requests, LoRAMapping([], [])) + assert worker.list_loras().issuperset( + {lora_request.lora_int_id + for lora_request in iter_lora_requests}) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 0000000000000..072a0d957758b --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,88 @@ +from typing import List, Optional + +import torch + +from vllm.lora.lora import LoRA + + +class DummyLoRAManager: + + def __init__(self): + super().__init__() + self._loras = {} + + def set_module_lora(self, module_name: str, 
lora: LoRA): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> Optional[LoRA]: + return self._loras.get(module_name, None) + + def init_random_lora(self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0): + lora = LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([weight.shape[1], rank], + dtype=weight.dtype, + device="cuda"), + lora_b=torch.rand([rank, weight.shape[0]], + dtype=weight.dtype, + device="cuda"), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand(5, + generate_embeddings_tensor, + dtype=weight.dtype, + device="cuda") + self.set_module_lora(module_name, lora) + + return lora + + def init_lora(self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None): + lora = LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([input_dim, rank], device="cuda"), + lora_b=torch.rand([rank, output_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: List[int], + noop_lora_index: List[int] = None, + rank=8, + ): + base_loras = [] + noop_lora_index = set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index, + ) + base_loras.append(base_lora) + packed_lora = LoRA.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora diff --git a/vllm/config.py b/vllm/config.py index 1adf830ffcc12..eef6e53be2855 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,5 @@ from typing import Optional, Union +from dataclasses import dataclass import os import torch @@ -350,6 +351,36 @@ def _verify_args(self) -> None: f"({self.max_num_seqs}).") +@dataclass +class LoRAConfig: + max_lora_rank: int + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[torch.dtype] = None + lora_extra_vocab_size: int = 256 + max_loras: Optional[int] = None + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + if scheduler_config.max_num_batched_tokens > 65528: + raise ValueError( + "Due to limitations of the custom LoRA CUDA kernel, " + "max_num_batched_tokens must be <= 65528 when " + "LoRA is enabled.") + + self.max_loras = scheduler_config.max_num_seqs + if self.max_cpu_loras is None: + self.max_cpu_loras = scheduler_config.max_num_seqs + elif self.max_cpu_loras < scheduler_config.max_num_seqs: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_num_seqs ({scheduler_config.max_num_seqs})") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca28bbdc2fb95..f8fb4c6ea1518 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,10 +1,11 @@ import enum import time -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union, Set from vllm.config import CacheConfig, SchedulerConfig from 
vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory +from vllm.lora.request import LoRARequest from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -36,6 +37,7 @@ def __init__( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], + lora_enabled: bool = False, ) -> None: self.scheduled_seq_groups = scheduled_seq_groups self.prompt_run = prompt_run @@ -47,11 +49,23 @@ def __init__( assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups + if lora_enabled: + self.num_loras = len(set(self.lora_requests)) + self._sort_by_lora_ids() + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) + def _sort_by_lora_ids(self) -> bool: + self.scheduled_seq_groups.sort(key=lambda g: ( + g.lora_request.lora_int_id if g.lora_request else 0, g.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return {g.lora_request for g in self.scheduled_seq_groups} + class Scheduler: @@ -59,9 +73,11 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + lora_enabled: bool = False, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + self.lora_enabled = lora_enabled self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -202,6 +218,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, + lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -274,6 +291,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=[], + lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -299,6 +317,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + lora_request=seq_group.lora_request, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 746b0e64ece7b..4d1233c473980 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) @dataclass @@ -33,6 +33,11 @@ class EngineArgs: revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None + enable_lora: bool = False + max_lora_rank: int = 8 + lora_extra_vocab_size: int = 256 + lora_dtype = 'bfloat16' + lora_max_cpu_loras: int = -1 def __post_init__(self): if self.tokenizer is None: @@ -182,6 +187,30 @@ def add_cli_args( choices=['awq', 'squeezellm', None], default=None, help='Method used to quantize the weights') + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='enable lora adapters') + parser.add_argument('--max-lora-rank', + type=int, + default=16, + help='max LoRA rank') + parser.add_argument('--lora-extra-vocab-size', + type=int, + default=256, + help='LoRA extra 
vocab size') + parser.add_argument('--lora-dtype', + type=str, + default=EngineArgs.dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help='data type for lora') + parser.add_argument( + '--lora-max-cpu-loras', + type=int, + default=-1, + help=('Maximum number of loras to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) return parser @classmethod @@ -194,7 +223,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, + Optional[LoRAConfig]]: model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, @@ -212,7 +242,13 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, self.max_paddings) - return model_config, cache_config, parallel_config, scheduler_config + lora_config = LoRAConfig( + max_lora_rank=self.max_lora_rank, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.lora_max_cpu_loras if self.lora_max_cpu_loras > + 0 else None) if self.enable_lora else None + return model_config, cache_config, parallel_config, scheduler_config, lora_config @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7dcd2eb632c4c..53bc7080b3273 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,6 +4,7 @@ from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union) +from vllm.lora.request import LoRARequest from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine @@ -198,6 +199,50 @@ async def step_async(self) -> List[RequestOutput]: return self._process_model_outputs(output, scheduler_outputs) + ignored + async def encode_request_async( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = await self.tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + async def add_request_async( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + return self.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + async def _run_workers_async( self, method: str, @@ -318,7 +363,7 @@ async def engine_step(self) -> bool: if self.engine_use_ray: await self.engine.add_request.remote(**new_request) else: - self.engine.add_request(**new_request) + await self.engine.add_request_async(**new_request) if finished_requests: await 
self._engine_abort(finished_requests) @@ -357,6 +402,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -370,7 +416,8 @@ async def add_request( logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " f"sampling params: {sampling_params}, " - f"prompt token ids: {shortened_token_ids}.") + f"prompt token ids: {shortened_token_ids}, " + f"lora_request: {lora_request}.") if not self.is_running: if self.start_engine_loop: @@ -382,12 +429,22 @@ async def add_request( "error that caused the background loop to stop " "(AsyncEngineDeadError).") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.engine.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + stream = self._request_tracker.add_request( request_id, prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + lora_request=lora_request, + ) return stream @@ -396,7 +453,8 @@ async def generate( prompt: Optional[str], sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None) -> RequestOutput: + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None) -> RequestOutput: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -410,6 +468,7 @@ async def generate( request_id: The unique id of the request. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. 
Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -420,11 +479,14 @@ async def generate( arrival_time = time.monotonic() try: - stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + stream = await self.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + ) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e33d8aa2a2131..c6e74b1d26586 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,8 +3,9 @@ from functools import partial from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from vllm.lora.request import LoRARequest from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray @@ -15,7 +16,7 @@ SequenceGroupMetadata, SequenceGroupOutputs, SequenceOutputs, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) + MultiLoRATokenizer) from vllm.utils import Counter if ray: @@ -65,6 +66,7 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], distributed_init_method: str, placement_group: Optional["PlacementGroup"], log_stats: bool, @@ -90,17 +92,13 @@ def __init__( self.cache_config = cache_config assert self.cache_config.sliding_window == getattr( self.model_config.hf_config, "sliding_window", None) + self.lora_config = lora_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats self._verify_args() - self.tokenizer = get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - tokenizer_revision=model_config.tokenizer_revision, - revision=model_config.revision) + self._init_tokenizer() self.seq_counter = Counter() # Create the parallel GPU workers. 
@@ -137,6 +135,7 @@ def _init_workers(self, distributed_init_method: str): self.scheduler_config, 0, distributed_init_method, + lora_config=self.lora_config, ) self.workers.append(worker) self._run_workers( @@ -150,6 +149,18 @@ def _init_workers(self, distributed_init_method: str): max_parallel_loading_workers, ) + def _init_tokenizer(self, **kwargs): + init_kwargs = dict( + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + init_kwargs.update(kwargs) + self.tokenizer: MultiLoRATokenizer = MultiLoRATokenizer( + self.model_config.tokenizer, **init_kwargs) + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers @@ -183,6 +194,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config, None, None, + lora_config=self.lora_config, )) self._run_workers( "init_model", @@ -198,6 +210,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) def _init_cache(self) -> None: """Profiles the memory usage and initializes the KV cache.""" @@ -246,6 +262,20 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": log_stats=not engine_args.disable_log_stats) return engine + def encode_request( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + def add_request( self, request_id: str, @@ -253,6 +283,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: """Add a request to the engine's request pool. @@ -270,20 +301,26 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") if arrival_time is None: arrival_time = time.monotonic() - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(prompt) + prompt_token_ids = self.encode_request( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + lora_request) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + arrival_time, lora_request) # Add the sequence group to the scheduler. 
self.scheduler.add_seq_group(seq_group) @@ -648,7 +685,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( - self.tokenizer, + self.tokenizer.get_lora_tokenizer(seq.lora_request), all_input_ids=seq.get_token_ids(), prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -689,11 +726,29 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): + if ((not sampling_params.ignore_eos) and seq.get_last_token_id() + == self.tokenizer.get_lora_tokenizer( + seq.lora_request).eos_token_id): seq.status = SequenceStatus.FINISHED_STOPPED return + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + def _run_workers_in_batch( self, workers, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b05ba71c6d352..9061909d72c33 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.lora.request import LoRARequest from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.outputs import RequestOutput @@ -109,6 +110,7 @@ def generate( sampling_params: Optional[SamplingParams] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -123,6 +125,7 @@ def generate( prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. Returns: A list of `RequestOutput` objects containing the generated @@ -149,7 +152,10 @@ def generate( prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request(prompt, sampling_params, token_ids) + self._add_request(prompt, + sampling_params, + token_ids, + lora_request=lora_request) return self._run_engine(use_tqdm) def _add_request( @@ -157,10 +163,14 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + lora_request=lora_request) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
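The changes above thread a LoRARequest from LLM.generate through the engine, scheduler, and workers. A minimal usage sketch under stated assumptions: the model name and adapter path are placeholders, and enable_lora is assumed to be forwarded from LLM(**kwargs) into EngineArgs as added in this PR.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Hypothetical base model and adapter directory; any compatible pair works.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# lora_int_id must be > 0; requests sharing an id reuse the loaded adapter.
sql_lora = LoRARequest("sql-adapter", 1, "/path/to/sql_lora_adapter")

outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    sampling_params,
    lora_request=sql_lora,
)
print(outputs[0].outputs[0].text)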
diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000000000..6ba8b0585847d --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,1002 @@ +# pylint: disable=unused-argument +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear, + MergedColumnParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim + +if TYPE_CHECKING: + pass + + +def _apply_lora( + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + indices: torch.Tensor, + output: torch.Tensor, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: (num_loras, lora_rank, hidden_dim) + lora_b_stacked: (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + return output.view_as(org_output) + + +def _apply_lora_packed_2slice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_dim: int, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of 2 sublayers + (slices) packed together (eg. gate_proj + up_proj -> + gate_up_proj). + + Both slices must have the same size (output_dim), meaning the output + tensor will have size output_dim*2. 
+ + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 2 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 2 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim*2) + output_dim: scalar + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, + 1.0, 0, output_dim) + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_dim, output_dim) + return output.view_as(org_output) + + +def _apply_lora_packed_3slice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, int], +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of 3 sublayers + (slices) packed together (attention projection). The + first slice (Q) may have different size from the two subsequent + slices (K, V). + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: 2 element tuple of (q_slice_size, kv_slice_size) + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, + 1.0, 0, output_slices[0]) + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_slices[0], output_slices[1]) + add_lora_slice(output, x, lora_a_stacked[2], lora_b_stacked[2], indices, 0, + 1.0, output_slices[0] + output_slices[1], output_slices[1]) + return output.view_as(org_output) + + +@dataclass +class LoRAMapping: + index_mapping: Tuple[int, ...] + prompt_mapping: Tuple[int, ...] + + def __eq__(self, __value: object) -> bool: + return (isinstance(__value, self.__class__) + and self.prompt_mapping == __value.prompt_mapping + and self.index_mapping == __value.index_mapping) + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +class LoRALayer(nn.Module): + + def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, + model_config: PretrainedConfig) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + """Sets the mapping indices.""" + ... 
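For clarity, here is an illustrative dense PyTorch reference (not part of the patch; the helper name is hypothetical) of what _apply_lora_packed_2slice computes through its two add_lora_slice calls, assuming the stacked layouts used by the tests in this PR: each lora_a_stacked[i] has shape (num_loras, 1, lora_rank, hidden_dim) and holds lora_a.T, and each lora_b_stacked[i] has shape (num_loras, 1, output_dim, lora_rank) and holds (lora_b * scaling).T.

import torch

def _reference_lora_packed_2slice(x: torch.Tensor,
                                  lora_a_stacked,
                                  lora_b_stacked,
                                  indices: torch.Tensor,
                                  output: torch.Tensor,
                                  output_dim: int) -> torch.Tensor:
    # An index of -1 leaves that row of the output unchanged.
    for row, idx in enumerate(indices.tolist()):
        if idx < 0:
            continue
        for slice_id in range(2):
            a = lora_a_stacked[slice_id][idx, 0]  # (lora_rank, hidden_dim)
            b = lora_b_stacked[slice_id][idx, 0]  # (output_dim, lora_rank)
            start = slice_id * output_dim
            # x[row] @ a.T @ b.T accumulates the LoRA delta for this slice.
            output[row, start:start + output_dim] += x[row] @ a.t() @ b.t()
    return output

The 3-slice variant follows the same pattern, except the second and third slices start at offsets output_slices[0] and output_slices[0] + output_slices[1], each with width output_slices[1].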
+ + +class LoRAVocabParallelEmbedding(LoRALayer): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + lora_vocab_start_idx = self.base_layer.org_vocab_size + weights_idx = None + if self.base_layer.vocab_end_index > lora_vocab_start_idx: + # We can start adding lora weights + weights_idx = max( + lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) + self.embeddings_slice = (self.base_layer.vocab_start_index - + self.base_layer.org_vocab_size + + weights_idx, + self.base_layer.vocab_end_index - + self.base_layer.org_vocab_size) + self.embeddings_weights = self.base_layer.weight.data[weights_idx:] + self.embeddings_weights.fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.embeddings_indices = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1]].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + self.embeddings_weights.copy_( + self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2]) + [self.embeddings_slice[0]:self.embeddings_slice[1]]) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.embeddings_indices = embeddings_indices + self.indices_len = indices_len + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], -1) + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + return full_output.view_as(full_output_org) + + +class LoRAColumnParallelLinear(LoRALayer): + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.output_dim = self.lora_b_stacked.shape[1] + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + 
def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply_weights(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @property + def linear_weights(self): + return self.base_layer.linear_weights + + +class LoRAMergedColumnParallelLinear2Slice(LoRAColumnParallelLinear): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + n_slices = 2 + if not (len(self.base_layer.output_sizes) == n_slices + and self.base_layer.output_sizes[0] + == self.base_layer.output_sizes[1]): + raise ValueError( + "LoRAColumnParallelLinear2Slice requires 2 slices with " + "the same size.") + self.tp_size = get_tensor_model_parallel_world_size() + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0] // 2, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + + self.indices: Optional[torch.Tensor] = None + self.output_dim = self.lora_b_stacked[0].shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_b_stacked[1][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[0][:, + start_idx:end_idx], lora_b[1][:, + start_idx:end_idx] + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_2slice( + 
x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_dim, + ) + return output + + +class LoRAQKVParallelLinear(LoRAColumnParallelLinear): + """ColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + # q, k, v + self.lora_a_stacked = (torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + self.lora_b_stacked = (torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size) + self.packed_indices: Optional[torch.Tensor] = None + self.standard_indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[1][index] = 0 + self.lora_a_stacked[2][index] = 0 + self.lora_b_stacked[2][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + 
(self.kv_shard_id + 1)] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) + else: + if lora_b[0] is not None: + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_b[1] is not None: + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + if lora_b[2] is not None: + self.lora_b_stacked[2][ + index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( + lora_b[2].T, non_blocking=True) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + if lora_a[2] is not None: + self.lora_a_stacked[2][ + index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( + lora_a[2].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_3slice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_slices, + ) + return output + + +class LoRARowParallelLinear(LoRALayer): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.base_layer.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.base_layer.weight.shape[1] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def 
forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return self.base_layer.weight + + +class LoRASampler(LoRALayer): + + def __init__( + self, + base_layer: Sampler, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + self.indices = None + self.indices_padded = None + self.indices_len = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1], ] = embeddings_tensor + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = sampler_indices + self.indices_padded = sampler_indices_padded + self.indices_len = indices_len + + def _get_logits( + self, + hidden_states: torch.Tensor, + embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + + logits[:, self.base_layer.org_vocab_size:] = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, + self.indices_padded[:self.indices_len[2]]).nan_to_num_( + nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + _apply_lora( + hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[1]], + logits, + ) + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> LoRALayer: + supported_layer_types = { + VocabParallelEmbedding: LoRAVocabParallelEmbedding, + ColumnParallelLinear: LoRAColumnParallelLinear, + QKVParallelLinear: LoRAQKVParallelLinear, + MergedColumnParallelLinear: LoRAMergedColumnParallelLinear2Slice, + RowParallelLinear: LoRARowParallelLinear, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_sampler( + layer: Sampler, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LoRASampler: + ret = LoRASampler(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000000000..042a98597ab26 --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,120 @@ +from typing import List, Optional + +import torch + + +class LoRA: + """A LoRA that is composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling 
= scaling + + @classmethod + def pack(cls, loras: List["LoRA"]) -> "PackedLoRA": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = PackedLoRA( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "LoRA": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + +class PackedLoRA(LoRA): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[int], + lora_a: List[torch.Tensor], + lora_b: List[torch.Tensor], + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + scaling=scaling, + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ + lora_alpha / self.rank for lora_alpha in self.lora_alphas + ] + + def optimize(self) -> "PackedLoRA": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: + continue + self.lora_b[i] *= self.scaling[i] + self.scaling[i] = 1 + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000000000..913234475b182 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,666 @@ +import copy +import json +import logging +import math +import os +import re +from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type, + Union) + +import safetensors.torch +import torch +from torch import nn + +from vllm.config import LoRAConfig +from vllm.utils import LRUCache + +from vllm.lora.layers import LoRALayer, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRA +from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule) + +logger = logging.getLogger(__name__) + +PACKED_MODULES_CFG = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], +} + +TARGET_MODULES_QKV = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", +] + +EMBEDDING_MODULES = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", +} + +EMBEDDING_PADDING_MODULES = ["lm_head"] + +_GLOBAL_LORA_ID = 0 + + +def convert_mapping( + mapping: LoRAMapping, lora_id_to_index: List[Optional[int]], + max_loras: int, vocab_size: int, extra_vocab_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """Converts 
LoRAMapping to index tensors.
+
+    Args:
+        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
+        lora_id_to_index: List mapping LoRA ids to LoRA indices.
+        max_loras: Maximum number of LoRAs.
+        vocab_size: Model vocab size.
+        extra_vocab_size: Extra vocab size each LoRA can have.
+
+    Returns:
+        A tuple of tensors:
+            base_indices: Tensor of shape [batch_size] mapping batch rows to
+                LoRA indices.
+            sampler_indices: Tensor of shape [batch_size] mapping requests to
+                LoRA indices for sampler. For generation, this will be the
+                same as base_indices. For prefill, this will map requests
+                to LoRA indices.
+            sampler_indices_padded: Tensor of shape [batch_size] mapping
+                requests to LoRA indices for sampler with padding.
+                Same as sampler_indices, but -1 is replaced with
+                max_loras.
+            embeddings_indices: Tensor of shape [2, batch_size] mapping
+                requests to embedding indices. First row is for embeddings
+                added by the LoRAs, second row is for the LoRA.lora_a
+                embeddings.
+            indices_len: List of lengths of the above tensors.
+    """
+    indices = list(mapping.index_mapping).copy()
+    embedding_indices = indices.copy()
+    lora_indices = indices.copy()
+    prompt_mapping = [
+        lora_id_to_index.index(x) if x > 0 else -1
+        for x in mapping.prompt_mapping
+    ]
+    lora_idx = None
+    for i in range(len(indices)):
+        # TODO index can be slow. optimize
+        lora_idx = (lora_id_to_index.index(indices[i])
+                    if indices[i] > 0 else -1)
+        embedding_indices[i] = lora_idx if indices[i] > 0 else 0
+        indices[i] = i
+        lora_indices[i] = lora_idx
+
+    indices = torch.tensor([indices, lora_indices, embedding_indices],
+                           dtype=torch.long,
+                           device="cuda")
+    prompt_mapping = torch.tensor(prompt_mapping,
+                                  device="cuda",
+                                  dtype=torch.long)
+    embeddings_indices = torch.stack([
+        indices[2] * extra_vocab_size,
+        indices[2] * (vocab_size + extra_vocab_size)
+    ])
+    embeddings_indices[embeddings_indices == -1] = max_loras - 1
+    base_indices = indices[1]
+    sampler_indices = prompt_mapping
+    sampler_indices_padded = sampler_indices.clone()
+    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
+    sampler_indices_padded = (
+        torch.arange(
+            0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
+        (sampler_indices_padded * len(sampler_indices_padded)))
+    indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
+                   sampler_indices_padded.shape[-1],
+                   embeddings_indices.shape[-1])
+
+    return (base_indices, sampler_indices, sampler_indices_padded,
+            embeddings_indices, indices_len)
+
+
+def get_lora_id():
+    global _GLOBAL_LORA_ID
+    _GLOBAL_LORA_ID += 1
+    return _GLOBAL_LORA_ID
+
+
+def _create_dummy_lora(module_name: str,
+                       input_dim: int,
+                       output_dim: int,
+                       rank: int,
+                       dtype: torch.dtype,
+                       device: torch.device,
+                       embeddings_tensor_dim: Optional[int] = None) -> "LoRA":
+    lora_a = torch.zeros([input_dim, rank], dtype=dtype, device=device)
+    lora_b = torch.zeros([rank, output_dim], dtype=dtype, device=device)
+    embeddings_tensor = torch.rand(
+        10, embeddings_tensor_dim, dtype=dtype,
+        device=device) if embeddings_tensor_dim else None
+    if str(device) == "cpu":
+        lora_a = lora_a.pin_memory()
+        lora_b = lora_b.pin_memory()
+        if embeddings_tensor is not None:
+            embeddings_tensor = embeddings_tensor.pin_memory()
+    return LoRA(
+        module_name,
+        rank=rank,
+        lora_alpha=1,
+        lora_a=lora_a,
+        lora_b=lora_b,
+        embeddings_tensor=embeddings_tensor,
+    )
+
+
+class LoRAModel:
+    """A LoRA fine-tuned model."""
+
+    def __init__(
+        self,
+        lora_model_id: int,
+        rank: int,
+        loras: Dict[str, LoRA],
+    ) -> None:
+        self.id 
= lora_model_id + assert (lora_model_id > + 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRA] = loras + + def get_lora(self, module_name: str) -> Optional[LoRA]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + rank: int, + lora_alpha: int, + tensors: Dict[str, torch.Tensor], + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + loras: Dict[str, LoRA] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + EMBEDDING_MODULES[embeddings_module]].to( + device=device, dtype=dtype) + if device == "cpu": + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRA(module_name, rank, lora_alpha, None, + None, lora_embeddings_tensor) + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if device == "cpu": + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + if any(name in module_name + for name in EMBEDDING_PADDING_MODULES + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if device == "cpu": + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for _, lora in loras.items(): + lora.optimize() + return cls(lora_model_id, rank, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint.""" + lora_config_path = os.path.join(lora_dir, "adapter_config.json") + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + if os.path.isfile(lora_tensor_path): + tensors = safetensors.torch.load_file(lora_tensor_path) + elif os.path.isfile(lora_bin_file_path): + tensors = torch.load(lora_bin_file_path) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path) + + with open(lora_config_path) as f: + config = json.load(f) + rank = config["r"] + lora_alpha = config["lora_alpha"] + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + 
if lora_model_id is None else lora_model_id, + rank=rank, + lora_alpha=lora_alpha, + tensors=tensors, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + ) + + +class LoRAModelManager: + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + lora_target_modules: the target modules patterns to be adapted. + Support both single module name and a list of module names. + packed_modules_mapping: the mapping for packed modules. vLLM + packs some modules into one module, e.g., qkv_proj + is packed of q_proj, k_proj, and v_proj. These modules + have a single layer in the original model, but they are split + into multiple layers in the adapted model. + """ + self.lora_config = lora_config + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_id_to_index: List[Optional[int]] = [None] * self._lora_slots + self.vocab_size = vocab_size + self.base_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices = torch.empty(self.max_num_seqs, + dtype=torch.long, + device="cuda") + self.sampler_indices_padded = torch.empty(self.max_num_seqs, + dtype=torch.long, + device="cuda") + self.embeddings_indices = torch.empty(2, + self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.offsets = [] + self.indices_len = [None] * 4 + + self.model: nn.Module = model + self.lora_target_modules: List[str] = ([ + lora_target_modules + ] if isinstance(lora_target_modules, str) else lora_target_modules) + self.lora_target_modules = copy.deepcopy(lora_target_modules) + self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, "LoRALayer"] = {} + self._registered_loras: Dict[int, LoRAModel] = {} + self._active_loras: Dict[int, None] = {} + self._last_mapping = None + self._create_lora_modules() + self.model.lora_manager = self + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def _lora_slots(self) -> int: + return self.max_num_seqs + + def __len__(self) -> int: + return len(self._registered_loras) + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id in self._active_loras: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_id_to_index) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_loras[lora_id] = None + lora_model = self._registered_loras[lora_id] + logger.debug( + f"Activating LoRA. 
int id: {lora_model.id}, slot index: {index}") + self.lora_id_to_index[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor) + else: + module.reset_lora(index) + return True + + def _deactivate_lora(self, lora_id: int): + try: + index = self.lora_id_to_index.index(lora_id) + self.lora_id_to_index[index] = None + except ValueError: + pass + + def deactivate_lora(self, lora_id: int) -> bool: + if lora_id in self._active_loras: + self._deactivate_lora(lora_id) + self._active_loras.pop(lora_id) + return True + return False + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + if lora.id not in self._registered_loras: + if len(self._registered_loras) >= self.capacity: + raise RuntimeError("No free LoRA slots.") + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + return True + return False + + def remove_lora(self, lora_id: int) -> bool: + """Remove a LoRAModel from the manager.""" + # TODO: should we check active lora? + self.deactivate_lora(lora_id) + return bool(self._registered_loras.pop(lora_id, None)) + + # TODO see if this can be vectorized + def convert_mapping(self, mapping: LoRAMapping) -> None: + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, + indices_len) = convert_mapping(mapping, self.lora_id_to_index, + self._lora_slots + 1, self.vocab_size, + self.lora_config.lora_extra_vocab_size) + self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self.embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + # Maintain the reference + self.indices_len[:] = indices_len + + def set_row_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + if self._last_mapping != lora_mapping: + self.convert_mapping(lora_mapping) + self._last_mapping = lora_mapping + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras) + + def get_lora(self, lora_id: int) -> Optional[LoRAModel]: + return self._registered_loras.get(lora_id, None) + + def remove_all_loras(self) -> bool: + """Remove all LoRAModels from the manager.""" + self._registered_loras.clear() + self.lora_id_to_index = [None] * self._lora_slots + self._active_loras.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name): + continue + + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.capacity, self.lora_config, + self.model.config)) + # (yard1): TODO make this more robust + if "lm_head" in module_name: + sampler_module = self.model.get_submodule("sampler") + new_module = replace_submodule( + self.model, "sampler", + from_layer_sampler(sampler_module, module, self.capacity, + self.lora_config, self.model.config)) + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + new_module.set_mapping(self.base_indices, self.sampler_indices, + self.sampler_indices_padded, + self.embeddings_indices, self.indices_len) + + def register_module(self, module_name: str, module: "LoRALayer"): + assert isinstance(module, LoRALayer) + self.modules[module_name] = module + + def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}) + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name) or not isinstance( + module, LoRALayer): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + if parts[-1] in EMBEDDING_MODULES: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = _create_dummy_lora( + module_name, + input_dim, + output_dim, + rank, + module.base_layer.weight.dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim) + else: + lora = _create_dummy_lora( + module_name, + module.base_layer.weight.shape[1], + module.base_layer.weight.shape[0], + rank, + module.base_layer.weight.dtype, + "cpu", + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras = [] + for r in replacements: + lora = _create_dummy_lora( + module_name + "." 
+ r, + module.base_layer.weight.shape[1], + module.base_layer.weight.shape[0] // len(replacements), + rank, + module.base_layer.weight.dtype, + "cpu", + ) + lora.optimize() + subloras.append(lora) + lora = LoRA.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.lora_target_modules) + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name) + if not replacements: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = LoRA.pack(replacement_loras) + + +class LoRALRUCache(LRUCache): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_lora_fn = deactivate_lora_fn + + def _on_remove(self, key: Hashable, value: Any): + logger.debug(f"Removing LoRA. int id: {key}") + self.deactivate_lora_fn(key) + return super()._on_remove(key, value) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, lora_target_modules, + packed_modules_mapping) + self._registered_loras: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_lora) + self._active_loras: LoRALRUCache = LoRALRUCache( + self.max_num_seqs, self._deactivate_lora) + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras.cache) + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + was_added = False + if lora.id not in self._registered_loras: + was_added = True + logger.debug(f"Adding LoRA. 
Model id: {lora.id}, " + f"int id: {lora.id}") + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + else: + # We always touch to update the LRU cache order + self._registered_loras.touch(lora.id) + return was_added + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_loras and len( + self._active_loras) >= self.max_num_seqs: + self._active_loras.remove_oldest() + result = super().activate_lora(lora_id) + # We always touch to update the LRU cache order + self._active_loras.touch(lora_id) + return result + + def remove_oldest_lora(self) -> bool: + if len(self._registered_loras) > 0: + self._registered_loras.remove_oldest() + return True + return False + + +def create_lora_adapter( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config:LoRAConfig, + target_modules: Union[str, + List[str]] = TARGET_MODULES_QKV, + lora_manager_cls:Type[LoRAModelManager] = LoRAModelManager, **kwargs)\ + -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not getattr(model, "supports_lora", False): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + lora_target_modules=target_modules, + **kwargs) + return lora_manager diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py new file mode 100644 index 0000000000000..ac96931b2d071 --- /dev/null +++ b/vllm/lora/punica.py @@ -0,0 +1,173 @@ +# Based on code from https://github.com/punica-ai/punica + +from typing import Optional + +import torch + +import_exc = None + +try: + import vllm._punica_C as punica_kernels +except ImportError as e: + import_exc = e + +if import_exc is None: + + def bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + ): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight + matrices. + indicies: Shape: `[B]`. Indices of the weight matrices. + layer_idx: Layer index of the weight matrices. + scale: Scaling factor. + """ + punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + + def add_lora(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + *, + buffer: Optional[torch.Tensor] = None): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + buffer: Optional. Shape: `[B, R]`. Temporary buffer. 
+ """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical innacuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, + 1.0) + punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, + scale) + + def add_lora_slice(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + y_offset: int, + y_slice_size: int, + *, + buffer: Optional[torch.Tensor] = None): + """ + Same as `add_lora` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. + + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv_low_level( + buffer, + x, + wa_t_all, + indicies, + layer_idx, + 1.0, + x.size(1), + buffer.size(1), + 0, + ) + punica_kernels.dispatch_bgmv_low_level( + y, + buffer, + wb_t_all, + indicies, + layer_idx, + scale, + buffer.size(1), + y_slice_size, + y_offset, + ) + +else: + + def _raise_exc( + *args, # pylint: disable=unused-argument + **kwargs # pylint: disable=unused-argument + ): + if torch.cuda.get_device_capability() < (8, 0): + raise ImportError( + "LoRA kernels require compute capability>=8.0") from import_exc + else: + raise import_exc + + bgmv = _raise_exc + add_lora = _raise_exc + add_lora_slice = _raise_exc + +__all__ = [ + "bgmv", + "add_lora", + "add_lora_slice", +] diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000000000..3ae5be59b1b88 --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + + +@dataclass +class LoRARequest: + lora_id: str + lora_int_id: int + lora_local_path: str + + def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError( + f"lora_int_id must be > 0, got {self.lora_int_id}") + + def __eq__(self, value: object) -> bool: + return isinstance(value, LoRARequest) and self.lora_id == value.lora_id + + def __hash__(self) -> int: + return self.lora_int_id diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000000000..f67a3812fb046 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,39 @@ +import logging +from typing import Tuple + +from torch import nn + +logger = logging.getLogger(__name__) + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = 
model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + """ + parts = name.split(".") + assert parts[0] == "base_model" + assert parts[1] == "model" + if parts[-1] == "weight": + assert parts[-2] == "lora_A" or parts[-2] == "lora_B" + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + + raise ValueError(f"{name} is unsupported format") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000000000..be6f4cf0589bd --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,266 @@ +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional, Set, Type, Union + +import torch + +from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_adapter) +from vllm.lora.request import LoRARequest +from vllm.lora.layers import LoRAMapping +from vllm.config import LoRAConfig + +logger = logging.getLogger(__name__) + + +class AbstractWorkerLoRAManager(ABC): + """Abstract class for managing LoRA models on the worker side.""" + + def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, + vocab_size: int, lora_config: LoRAConfig, + device: torch.device): + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.device = device + self.lora_config = lora_config + + @abstractproperty + def is_enabled(self) -> bool: + ... + + @abstractmethod + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + ... + + @abstractmethod + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + ... + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + ... + + @abstractmethod + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + ... + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + ... + + @abstractmethod + def remove_all_loras(self) -> bool: + ... + + @abstractmethod + def list_loras(self) -> Set[int]: + ... 
+ + +class DisabledWorkerLoRAManager(AbstractWorkerLoRAManager): + """WorkerLoRAManager that does nothing.""" + + @property + def is_enabled(self) -> bool: + return False + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + return model + + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + return + + def add_lora(self, lora_request: LoRARequest) -> bool: + return False + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + return False + + def remove_lora(self, lora_id: int) -> bool: + return False + + def remove_all_loras(self) -> bool: + return + + def list_loras(self) -> Set[int]: + return set() + + +class WorkerLoRAManager(AbstractWorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_model_cls: Type[LoRAModel] = LoRAModel, + ): + self._lora_manager: Optional[LoRAModelManager] = None + self._lora_model_cls = lora_model_cls + super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, + lora_config, device) + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_adapter( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + target_modules=target_modules, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + lora_manager_cls=self._lora_manager_cls, + ) + self._lora_manager = lora_manager + return lora_manager.model + + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self._apply_loras(lora_requests) + self._lora_manager.set_row_lora_mapping(lora_mapping) + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_that_exist = self.list_loras() + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.max_num_seqs: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.max_num_seqs}).") + + new_loras = set(loras_map) + loras_to_add = new_loras - loras_that_exist + loras_to_remove = loras_that_exist - new_loras + + for lora_id in loras_to_remove: + self.remove_lora(lora_id) + + for lora_id in loras_to_add: + self.add_lora(loras_map[lora_id]) + + def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + try: + lora = self._lora_model_cls.from_local_checkpoint( + lora_request.lora_local_path, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + ) + except Exception as e: + raise RuntimeError( + f"Loading lora {lora_request.lora_local_path} failed") from e + if lora.rank > self.lora_config.max_lora_rank: + raise ValueError( + f"LoRA rank {lora.rank} is greater than max_lora_rank " + f"{self.lora_config.max_lora_rank}.") + return lora + + def 
add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + return self._lora_manager.add_lora( + self._lora_manager.create_dummy_lora(lora_request.lora_int_id, + rank)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + self._lora_manager.activate_lora(lora.id) + return loaded + + def remove_lora(self, lora_id: int) -> bool: + return self._lora_manager.remove_lora(lora_id) + + def remove_all_loras(self) -> bool: + self._lora_manager.remove_all_loras() + + def list_loras(self) -> Set[int]: + return set(self._lora_manager.list_loras()) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _lora_manager_cls: Type[ + LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_adapter( + model, + target_modules=target_modules, + lora_manager_cls=self._lora_manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._lora_manager = lora_manager + return lora_manager.model + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.max_num_seqs: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.max_num_seqs}).") + for lora in loras_map.values(): + self.add_lora(lora) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_loras(): + # Remove before we load the new lora to save memory + if len(self._lora_manager) + 1 > self._lora_manager.capacity: + self._lora_manager.remove_oldest_lora() + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + self._lora_manager.activate_lora(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c874ec5921155..5bce287a92ae5 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -29,9 +29,24 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). """ - def __init__(self, vocab_size: int) -> None: + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: super().__init__() self.vocab_size = vocab_size + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. 
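+        # Shape note (inferred from the surrounding code): hidden_states is + # [num_tokens, hidden_size] and embedding is this rank's shard of the + # (padded) vocabulary, so the matmul yields per-shard logits; the + # all-gather below reassembles the full padded vocab dimension before + # it is sliced back to org_vocab_size.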
+ logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :self.org_vocab_size] + return logits def forward( self, @@ -44,8 +59,7 @@ def forward( hidden_states = _prune_hidden_states(hidden_states, input_metadata) # Get the logits for the next tokens. - logits = _get_logits(hidden_states, embedding, embedding_bias, - self.vocab_size) + logits = self._get_logits(hidden_states, embedding, embedding_bias) # Apply logits processors (if any). logits = _apply_logits_processors(logits, input_metadata) @@ -97,19 +111,6 @@ def forward( prompt_logprobs, sample_logprobs) -def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor], - vocab_size: int) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :vocab_size] - return logits - - def _prune_hidden_states( hidden_states: torch.Tensor, input_metadata: InputMetadata, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b08d5555b0faa..9e4ac26e73d00 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -43,16 +43,19 @@ class VocabParallelEmbedding(torch.nn.Module): num_embeddings: vocabulary size. embedding_dim: size of hidden state. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). """ def __init__(self, num_embeddings: int, embedding_dim: int, - params_dtype: Optional[torch.dtype] = None): + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None): super().__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings + self.org_vocab_size = org_num_embeddings or num_embeddings self.num_embeddings_padded = pad_vocab_size(num_embeddings) self.embedding_dim = embedding_dim if params_dtype is None: @@ -77,7 +80,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): parallel_dim = param.parallel_dim - assert loaded_weight.shape[parallel_dim] == self.num_embeddings + assert loaded_weight.shape[parallel_dim] == self.org_vocab_size loaded_weight = loaded_weight[self.vocab_start_index:self. vocab_end_index] param[:loaded_weight.shape[0]].data.copy_(loaded_weight) @@ -114,14 +117,17 @@ class ParallelLMHead(VocabParallelEmbedding): embedding_dim: size of hidden state. bias: whether to use bias. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). 
""" def __init__(self, num_embeddings: int, embedding_dim: int, bias: bool = False, - params_dtype: Optional[torch.dtype] = None): - super().__init__(num_embeddings, embedding_dim, params_dtype) + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 54b87c4b866e3..cf84b9810c575 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -1,12 +1,12 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Type +from typing import Optional, Type import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig +from vllm.config import ModelConfig, LoRAConfig from vllm.model_executor.models import * from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -58,7 +58,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") -def get_model(model_config: ModelConfig) -> nn.Module: +def get_model(model_config: ModelConfig, + lora_config: Optional[LoRAConfig] = None) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. @@ -87,7 +88,12 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - model = model_class(model_config.hf_config, linear_method) + # TODO(yard1): Clean this up (lora_config) + try: + model = model_class(model_config.hf_config, linear_method, + lora_config) + except TypeError: + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8e7344da4888e..999c1097d0a42 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -43,6 +43,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -223,14 +224,19 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ LlamaDecoderLayer(config, linear_method) @@ -264,18 +270,25 @@ def forward( class LlamaForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config 
self.linear_method = linear_method - self.model = LlamaModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead(unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index d18572610741c..c67c3fae2028a 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -43,6 +43,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -217,15 +218,20 @@ def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ MistralDecoderLayer(config, linear_method) @@ -259,18 +265,27 @@ def forward( class MistralForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = MistralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = MistralModel(config, + linear_method, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead(unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/outputs.py b/vllm/outputs.py index fe54926e06e64..534e9d5ea8a53 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -2,6 +2,7 @@ from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, SequenceStatus) +from vllm.lora.request import LoRARequest class CompletionOutput: @@ -16,6 +17,7 @@ class CompletionOutput: logprobs: The log probabilities of the top probability words at each position if the logprobs are requested. finish_reason: The reason why the sequence is finished. + lora_request: The LoRA request that was used to generate the output. 
""" def __init__( @@ -26,6 +28,7 @@ def __init__( cumulative_logprob: float, logprobs: Optional[SampleLogprobs], finish_reason: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: self.index = index self.text = text @@ -33,6 +36,7 @@ def __init__( self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs self.finish_reason = finish_reason + self.lora_request = lora_request def finished(self) -> bool: return self.finish_reason is not None @@ -56,6 +60,7 @@ class RequestOutput: prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. + lora_request: The LoRA request that was used to generate the output. """ def __init__( @@ -66,6 +71,7 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -73,6 +79,7 @@ def __init__( self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished + self.lora_request = lora_request @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -108,8 +115,13 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": prompt_token_ids = seq_group.prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished) + return cls(seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + lora_request=seq_group.lora_request) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -117,4 +129,5 @@ def __repr__(self) -> str: f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " - f"finished={self.finished})") + f"finished={self.finished}, " + f"lora_request={self.lora_request})") diff --git a/vllm/sequence.py b/vllm/sequence.py index ecfaee6e8c3d6..06170ab79d69a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from vllm.block import LogicalTokenBlock from vllm.sampling_params import SamplingParams +from vllm.lora.request import LoRARequest PromptLogprobs = List[Optional[Dict[int, float]]] SampleLogprobs = List[Dict[int, float]] @@ -105,6 +106,7 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + lora_request: LoRA request. """ def __init__( @@ -113,10 +115,12 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size + self.lora_request = lora_request self.data = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -228,6 +232,7 @@ class SequenceGroup: seqs: The list of sequences. sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. + lora_request: LoRA request. 
""" def __init__( @@ -236,11 +241,13 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -335,6 +342,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + lora_request: LoRA request. """ def __init__( @@ -344,12 +352,18 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.lora_request = lora_request + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 class SequenceOutputs: diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 5b0481480a63b..b84f50c3bd5d7 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -4,6 +4,8 @@ PreTrainedTokenizerFast) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.utils import make_async, LRUCache logger = init_logger(__name__) @@ -69,6 +71,86 @@ def get_tokenizer( return tokenizer +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[PreTrainedTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_local_path, *args, + **kwargs) + except OSError as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + f"No tokenizer found in {lora_request.lora_local_path}, " + "using base model tokenizer instead. 
" + f"(Exception: {str(e)})") + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) + + +class MultiLoRATokenizer: + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def ping(self): + return True + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/utils.py b/vllm/utils.py index 47e51048fed45..9282db842c1d2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,9 +4,20 @@ import psutil import torch +import asyncio +from functools import partial +from typing import ( + Awaitable, + Callable, + TypeVar, +) +from collections import OrderedDict +from typing import Any, Hashable, Optional from vllm._C import cuda_utils +T = TypeVar("T") + class Device(enum.Enum): GPU = enum.auto() @@ -27,6 +38,69 @@ def reset(self) -> None: self.counter = 0 +class LRUCache: + + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> Any: + return self.get(key) + + def __setitem__(self, key: Hashable, value: Any) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, 
value: Any) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def _on_remove(self, key: Hashable, value: Any): + pass + + def remove_oldest(self): + if not self.cache: + return + key, value = self.cache.popitem(last=False) + self._on_remove(key, value) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, key: Hashable, default_value: Optional[Any] = None) -> Any: + run_on_remove = key in self.cache + value = self.cache.pop(key, default_value) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest() + self.cache.clear() + + def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html @@ -53,3 +127,19 @@ def random_uuid() -> str: def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() + + +def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """Take a blocking function, and run it in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args, **kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 702767ebd8d09..d316b9588bf75 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,12 +1,13 @@ """A GPU worker class.""" +import gc import os -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Tuple, Set, Optional import torch import torch.distributed from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) @@ -14,6 +15,14 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.utils import get_gpu_memory +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import ( + DisabledWorkerLoRAManager, + LRUCacheWorkerLoRAManager, +) +from vllm.lora.layers import LoRAMapping + +LORA_WARMUP_RANK = 8 class Worker: @@ -31,12 +40,14 @@ def __init__( scheduler_config: SchedulerConfig, rank: Optional[int] = None, distributed_init_method: Optional[str] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.rank = rank self.distributed_init_method = distributed_init_method + self.lora_config = lora_config # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). @@ -46,6 +57,7 @@ def __init__( self.cache_engine = None self.cache_events = None self.gpu_cache = None + self.lora_manager = None def init_model(self): # This env var set by Ray causes exceptions with graph building.
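As a quick illustration of the LRUCache added to vllm/utils.py above, here is a standalone sketch of its eviction semantics (not part of the diff itself):

    from vllm.utils import LRUCache

    cache = LRUCache(capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")     # Touching "a" makes "b" the least recently used entry.
    cache.put("c", 3)  # Over capacity, so remove_oldest() evicts "b".
    assert "b" not in cache
    assert cache.get("a") == 1 and cache.get("c") == 3

Subclasses can override the _on_remove() hook (a no-op by default) to run cleanup whenever an entry is evicted or popped.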
@@ -69,7 +81,21 @@ def init_model(self): set_random_seed(self.model_config.seed) def load_model(self): - self.model = get_model(self.model_config) + self.model = get_model(self.model_config, self.lora_config) + + vocab_size = self.model.config.vocab_size + + if self.lora_config: + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) + self.model = self.lora_manager.create_lora_adapter(self.model) + else: + self.lora_manager = DisabledWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) @torch.inference_mode() def profile_num_available_blocks( @@ -91,6 +117,24 @@ def profile_num_available_blocks( sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs + + # This represents the maximum number of different requests + # that will have unique LoRAs, and therefore the maximum amount of + # memory consumption. Create dummy LoRA requests (one per sequence) + # so that profiling also accounts for the memory used by LoRAs. + dummy_lora_requests = [] + if self.lora_config: + for idx in range(max_num_seqs): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_id=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + seqs = [] for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + @@ -102,11 +146,21 @@ def profile_num_available_blocks( seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, + lora_request=dummy_lora_requests[group_id] + if dummy_lora_requests else None, ) seqs.append(seq) - input_tokens, input_positions, input_metadata = self._prepare_inputs( - seqs) + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + prepared_lora_requests, + ) = self._prepare_inputs(seqs) + + if dummy_lora_requests: + self.apply_loras(prepared_lora_requests, lora_mapping) # Execute the model. num_layers = self.model_config.get_num_layers(self.parallel_config) @@ -131,6 +185,8 @@ def profile_num_available_blocks( num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + self.lora_manager.remove_all_loras() + gc.collect() torch.cuda.empty_cache() # Reset the seed to ensure that the random state is not affected by @@ -151,7 +207,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: def _prepare_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, LoRAMapping, + Set[LoRARequest]]: seq_groups: List[Tuple[List[int], SamplingParams]] = [] input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -160,6 +217,9 @@ def _prepare_inputs( selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + lora_requests: Set[LoRARequest] = set() + lora_index_mapping: List[List[int]] = [] + lora_prompt_mapping: List[int] = [] # Add prompt tokens.
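+        # For each sequence group that uses a LoRA, the loop below records + # the LoRA id once per prompt token in lora_index_mapping and once + # per sequence in lora_prompt_mapping; these are later combined into + # the LoRAMapping passed to apply_loras().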
prompt_lens: List[int] = [] @@ -170,6 +230,7 @@ def _prepare_inputs( seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id # Use any sequence in the group. seq_id = seq_ids[0] @@ -187,6 +248,17 @@ def _prepare_inputs( categorized_sample_indices_start_idx) categorized_sample_indices_start_idx += 1 + if lora_id > 0: + # If we are preparing inputs for the warmup step, we want the + # LoRA computation to take up the maximum possible amount of + # memory; that way we get a tighter upper bound on the amount + # of memory we can use and therefore do not OOM later. Here + # we also record the LoRA mapping for this prompt, which is + # used during generation. + lora_requests.add(seq_group_metadata.lora_request) + lora_index_mapping.append([lora_id] * prompt_len) + lora_prompt_mapping.append(lora_id) + input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. @@ -233,6 +305,7 @@ def _prepare_inputs( seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id num_seqs = len(seq_ids) selected_token_indices.extend( @@ -255,6 +328,7 @@ def _prepare_inputs( if self.sliding_window is not None: context_len = min(context_len, self.sliding_window) input_positions.append([position]) + lora_index_mapping.append([lora_id]) block_table = seq_group_metadata.block_tables[seq_id] @@ -274,6 +348,11 @@ def _prepare_inputs( block_table = block_table[-sliding_window_blocks:] generation_block_tables.append(block_table) + # Update LoRA mapping. + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + lora_prompt_mapping.append(lora_id) + padded_input_tokens = [ _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens ] @@ -281,6 +360,10 @@ def _prepare_inputs( _pad_to_max(positions, max_seq_len, pad=0) for positions in input_positions ] + padded_lora_input_mapping = [ + _pad_to_max(mapping, max_seq_len, pad=0) + for mapping in lora_index_mapping + ] padded_slot_mapping = [ _pad_to_max(mapping, max_seq_len, pad=-1) for mapping in slot_mapping ] @@ -318,6 +401,14 @@ def _prepare_inputs( for seq_group_metadata in seq_group_metadata_list: seq_data.update(seq_group_metadata.seq_data) + flat_padded_lora_input_mapping = [ + item for sublist in padded_lora_input_mapping for item in sublist + ] + lora_mapping = LoRAMapping( + flat_padded_lora_input_mapping, + lora_prompt_mapping, + ) + input_metadata = InputMetadata( seq_groups=seq_groups, seq_data=seq_data, @@ -330,7 +421,7 @@ def _prepare_inputs( categorized_sample_indices=categorized_sample_indices, sliding_window=self.sliding_window, ) - return tokens_tensor, positions_tensor, input_metadata + return tokens_tensor, positions_tensor, input_metadata, lora_mapping, lora_requests @torch.inference_mode() def execute_model( @@ -362,8 +453,20 @@ def execute_model( return {} # Prepare input tensors.
- input_tokens, input_positions, input_metadata = self._prepare_inputs( - seq_group_metadata_list) + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + lora_requests, + ) = self._prepare_inputs(seq_group_metadata_list) + + if self.lora_config: + lora_requests = [ + seq_group_metadata.lora_request + for seq_group_metadata in seq_group_metadata_list + ] + self.apply_loras(lora_requests, lora_mapping) # Execute the model. output = self.model( @@ -375,6 +478,19 @@ def execute_model( ) return output + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self.lora_manager.apply_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.lora_manager.list_loras() + def _init_distributed_environment( parallel_config: ParallelConfig,