diff --git a/setup.py b/setup.py index ed4b88364a6d..34be6a6412a7 100755 --- a/setup.py +++ b/setup.py @@ -280,6 +280,14 @@ def run(self): print(f"Copying {file} to {dst_file}") self.copy_file(file, dst_file) + # copy these folders to use the vllm_flash_attn rotary_kernel. + for folder in ("layers", "ops"): + src = os.path.join(self.build_lib, "vllm", "vllm_flash_attn", + folder) + out = os.path.join("vllm", "vllm_flash_attn", folder) + print(f"Copying {folder} from vllm/vllm_flash_attn") + self.copy_tree(src, out) + class repackage_wheel(build_ext): """Extracts libraries and other files from an existing wheel.""" @@ -400,6 +408,31 @@ def run(self) -> None: package_data[package_name].append(file_name) + # Extract and include the layers and ops of rotary embedding. + folders_to_copy = {"layers", "ops"} + for folder in folders_to_copy: + folder_path = f"vllm/vllm_flash_attn/{folder}" + folder_files = [ + f for f in wheel.filelist + if f.filename.startswith(folder_path) + ] + + if folder_files: + print(f"Include {folder} folder from vllm/vllm_flash_attn") + for file in folder_files: + wheel.extract(file) + + # Add the file to package_data if it's not a Python file + rel_path = file.filename.split("/") + # vllm/vllm_flash_attn/folder/file + if len(rel_path) >= 4: + package_name = "vllm.vllm_flash_attn." + folder + file_name = rel_path[-1] + if not file_name.endswith(".py"): + if package_name not in package_data: + package_data[package_name] = [] + package_data[package_name].append(file_name) + def _is_hpu() -> bool: # if VLLM_TARGET_DEVICE env var was set explicitly, skip HPU autodetection