diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 9a49ed2a3e61..f9c4e3c003c9 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -65,9 +65,9 @@ jobs: run: | source ${GITHUB_WORKSPACE}/venv/bin/activate cd jax - python build/build.py \ - --bazel_options=--color=yes \ - --bazel_options=--copt=-fsanitize=address \ + python build/build.py jaxlib --verbose \ + --bazel_build_options='--verbose_failures=true' \ + --bazel_build_options='--copt=-fsanitize=address' \ --clang_path=/usr/bin/clang-18 pip install dist/jaxlib-*.whl pip install -e . diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 2b4a616e224a..30740c054951 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -40,10 +40,9 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang ` - --verbose + python.exe build\build.py jaxlib --verbose ` + --bazel_build_options='--verbose_failures=true' ` + --bazel_build_options='--config=win_clang' - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 3173b81e6819..6e273e3b03ce 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -49,9 +49,9 @@ jobs: python -m pip install -r build/test-requirements.txt python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH - python.exe build\build.py ` - --bazel_options=--color=yes ` - --bazel_options=--config=win_clang + python.exe build\build.py jaxlib --verbose ` + --bazel_build_options='--verbose_failures=true' ` + --bazel_build_options='--config=win_clang' - uses: 
actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: diff --git a/.github/workflows/windows_presubmit.yml b/.github/workflows/windows_presubmit.yml new file mode 100644 index 000000000000..bbdbc68e4fad --- /dev/null +++ b/.github/workflows/windows_presubmit.yml @@ -0,0 +1,70 @@ +name: Presubmit - Windows CPU +on: + # TODO(DO_NOT_SUBMIT): temporary check + push: + branches: + - main + pull_request: + branches: + - main + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + +env: + DISTUTILS_USE_SDK: 1 + MSSdk: 1 + +jobs: + presubmit-win-wheels: + if: ${{ (github.event.action != 'labeled') || (github.event.label.name == 'windows:force-run')}} + strategy: + fail-fast: true + matrix: + os: [windows-2019-32core] + arch: [AMD64] + pyver: ['3.10'] + name: ${{ matrix.os }} Windows build + runs-on: ${{ matrix.os }} + + steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + + - name: Install LLVM/Clang + run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.pyver }} + cache: 'pip' + + - name: Build wheels + env: + BAZEL_VC: "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC" + JAXLIB_RELEASE: true + run: | + python -m pip install -r build/test-requirements.txt + python -m pip install --upgrade numpy==2.0.0 scipy==1.13.1 + "C:\\msys64\\;C:\\msys64\\usr\\bin\\;" >> $env:GITHUB_PATH + python.exe build\build.py jaxlib --verbose ` + --bazel_build_options='--verbose_failures=true' ` + --bazel_build_options="--config=win_clang" ` + --clang_path="C:\Program Files\LLVM\bin\clang.exe" ` + 
--bazel_build_options='--color=yes' + + - name: Run tests + env: + JAX_ENABLE_CHECKS: true + JAX_SKIP_SLOW_TESTS: true + PY_COLORS: 1 + run: | + python -m pip install --find-links ${{ github.workspace }}\dist jaxlib + python -m pip install -e ${{ github.workspace }} + echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" + pytest -n auto --tb=short tests examples \ No newline at end of file diff --git a/build/build.py b/build/build.py index 62e4217c10a2..9a0b9328c6dc 100755 --- a/build/build.py +++ b/build/build.py @@ -14,94 +14,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# Helper script for building JAX's libjax easily. +# CLI for building jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, +# jax-rocm-pjrt and for updating the requirements_lock.txt files. import argparse +import asyncio import logging import os import platform -import textwrap +import sys -from tools import utils +from tools import command, utils +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) -def write_bazelrc(*, remote_build, - cuda_version, cudnn_version, rocm_toolkit_path, - cpu, cuda_compute_capabilities, - rocm_amdgpu_targets, target_cpu_features, - wheel_cpu, enable_mkl_dnn, use_clang, clang_path, - clang_major_version, python_version, - enable_cuda, enable_nccl, enable_rocm, - use_cuda_nvcc): - - with open("../.jax_configure.bazelrc", "w") as f: - if not remote_build: - f.write(textwrap.dedent("""\ - build --strategy=Genrule=standalone - """)) - - if use_clang: - f.write(f'build --action_env CLANG_COMPILER_PATH="{clang_path}"\n') - f.write(f'build --repo_env CC="{clang_path}"\n') - f.write(f'build --repo_env BAZEL_COMPILER="{clang_path}"\n') - f.write('build --copt=-Wno-error=unused-command-line-argument\n') - if clang_major_version in (16, 17, 18): - # Necessary due to XLA's old version of upb. 
See: - # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505 - f.write("build --copt=-Wno-gnu-offsetof-extensions\n") - - if rocm_toolkit_path: - f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" - .format(rocm_toolkit_path=rocm_toolkit_path)) - if rocm_amdgpu_targets: - f.write( - f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"\n') - if cpu is not None: - f.write(f"build --cpu={cpu}\n") - - if target_cpu_features == "release": - if wheel_cpu == "x86_64": - f.write("build --config=avx_windows\n" if utils.is_windows() - else "build --config=avx_posix\n") - elif target_cpu_features == "native": - if utils.is_windows(): - print("--target_cpu_features=native is not supported on Windows; ignoring.") - else: - f.write("build --config=native_arch_posix\n") - - if enable_mkl_dnn: - f.write("build --config=mkl_open_source_only\n") - if enable_cuda: - f.write("build --config=cuda\n") - if use_cuda_nvcc: - f.write("build --config=build_cuda_with_nvcc\n") - else: - f.write("build --config=build_cuda_with_clang\n") - f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if cuda_version: - f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" - .format(cuda_version=cuda_version)) - if cudnn_version: - f.write("build --repo_env HERMETIC_CUDNN_VERSION=\"{cudnn_version}\"\n" - .format(cudnn_version=cudnn_version)) - if cuda_compute_capabilities: - f.write( - f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') - if enable_rocm: - f.write("build --config=rocm_base\n") - if not enable_nccl: - f.write("build --config=nonccl\n") - if use_clang: - f.write("build --config=rocm\n") - f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n") - if python_version: - f.write( - "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format( - 
python_version=python_version)) BANNER = r""" _ _ __ __ | | / \ \ \/ / @@ -112,421 +42,578 @@ def write_bazelrc(*, remote_build, """ EPILOG = """ +From the root directory of the JAX repository, run + python build/build.py [jaxlib | jax-cuda-plugin | jax-cuda-pjrt | jax-rocm-plugin | jax-rocm-pjrt] -From the 'build' directory in the JAX repository, run - python build.py + to build one of: jaxlib, jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt or - python3 build.py -to download and build JAX's XLA (jaxlib) dependency. + python build/build.py requirements_update to update the requirements_lock.txt """ +# Define the build target for each artifact. +ARTIFACT_BUILD_TARGET_DICT = { + "jaxlib": "//jaxlib/tools:build_wheel", + "jax-cuda-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-cuda-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", + "jax-rocm-plugin": "//jaxlib/tools:build_gpu_kernels_wheel", + "jax-rocm-pjrt": "//jaxlib/tools:build_gpu_plugin_wheel", +} -def _parse_string_as_bool(s): - """Parses a string as a boolean argument.""" - lower = s.lower() - if lower == "true": - return True - elif lower == "false": - return False - else: - raise ValueError(f"Expected either 'true' or 'false'; got {s}") +def add_python_version_argument(parser: argparse.ArgumentParser): + parser.add_argument( + "--python_version", + type=str, + choices=["3.10", "3.11", "3.12", "3.13"], + default=f"{sys.version_info.major}.{sys.version_info.minor}", + help= + """ + Hermetic Python version to use. Default is to use the version of the + Python binary that executed the CLI. 
+ """, + ) -def add_boolean_argument(parser, name, default=False, help_str=None): - """Creates a boolean flag.""" - group = parser.add_mutually_exclusive_group() - group.add_argument( - "--" + name, - nargs="?", - default=default, - const=True, - type=_parse_string_as_bool, - help=help_str) - group.add_argument("--no" + name, dest=name, action="store_false") +def add_cuda_version_argument(parser: argparse.ArgumentParser): + parser.add_argument( + "--cuda_version", + type=str, + default=None, + help= + """ + Hermetic CUDA version to use. Default is to use the version specified + in the .bazelrc. + """, + ) -def _get_editable_output_paths(output_path): - """Returns the paths to the editable wheels.""" - return ( - os.path.join(output_path, "jaxlib"), - os.path.join(output_path, "jax_gpu_pjrt"), - os.path.join(output_path, "jax_gpu_plugin"), + +def add_cudnn_version_argument(parser: argparse.ArgumentParser): + parser.add_argument( + "--cudnn_version", + type=str, + default=None, + help= + """ + Hermetic cuDNN version to use. Default is to use the version specified + in the .bazelrc. + """, ) -def main(): - cwd = os.getcwd() - parser = argparse.ArgumentParser( - description="Builds jaxlib from source.", epilog=EPILOG) - add_boolean_argument( - parser, - "verbose", - default=False, - help_str="Should we produce verbose debugging output?") +def add_disable_nccl_argument(parser: argparse.ArgumentParser): parser.add_argument( - "--bazel_path", - help="Path to the Bazel binary to use. The default is to find bazel via " - "the PATH; if none is found, downloads a fresh copy of bazel from " - "GitHub.") + "--disable_nccl", + action="store_true", + help="Should NCCL be disabled?", + ) + + +def add_cuda_compute_capabilities_argument(parser: argparse.ArgumentParser): parser.add_argument( - "--python_bin_path", - help="Path to Python binary whose version to match while building with " - "hermetic python. The default is the Python interpreter used to run the " - "build script. 
DEPRECATED: use --python_version instead.") + "--cuda_compute_capabilities", + type=str, + default=None, + help= + """ + A comma-separated list of CUDA compute capabilities to support. Default + is to use the values specified in the .bazelrc. + """, + ) + + +def add_build_cuda_with_clang_argument(parser: argparse.ArgumentParser): parser.add_argument( - "--target_cpu_features", - choices=["release", "native", "default"], - default="release", - help="What CPU features should we target? 'release' enables CPU " - "features that should be enabled for a release build, which on " - "x86-64 architectures enables AVX. 'native' enables " - "-march=native, which generates code targeted to use all " - "features of the current machine. 'default' means don't opt-in " - "to any architectural features and use whatever the C compiler " - "generates by default.") - add_boolean_argument( - parser, - "use_clang", - default = "true", - help_str=( - "DEPRECATED: This flag is redundant because clang is " - "always used as default compiler." - ), + "--build_cuda_with_clang", + action="store_true", + help=""" + Should CUDA code be compiled using Clang? The default behavior is to + compile CUDA with NVCC. Ignored if --use_ci_bazelrc_flags is set; CI + builds always build CUDA with NVCC. + """, ) + + +def add_rocm_version_argument(parser: argparse.ArgumentParser): parser.add_argument( - "--clang_path", - help=( - "Path to clang binary to use. The default is " - "to find clang via the PATH." - ), - ) - add_boolean_argument( - parser, - "enable_mkl_dnn", - default=True, - help_str="Should we build with MKL-DNN enabled?", - ) - add_boolean_argument( - parser, - "enable_cuda", - help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN." - ) - add_boolean_argument( - parser, - "use_cuda_nvcc", - default=True, - help_str=( - "Should we build CUDA code using NVCC compiler driver? The default value " - "is true. 
If --nouse_cuda_nvcc flag is used then CUDA code is built " - "by clang compiler." - ), - ) - add_boolean_argument( - parser, - "build_gpu_plugin", - default=False, - help_str=( - "Are we building the gpu plugin in addition to jaxlib? The GPU " - "plugin is still experimental and is not ready for use yet." - ), + "--rocm_version", + type=str, + default="60", + help="ROCm version to use", ) + + +def add_rocm_amdgpu_targets_argument(parser: argparse.ArgumentParser): parser.add_argument( - "--build_gpu_kernel_plugin", - choices=["cuda", "rocm"], - default="", - help=( - "Specify 'cuda' or 'rocm' to build the respective kernel plugin." - " When this flag is set, jaxlib will not be built." - ), - ) - add_boolean_argument( - parser, - "build_gpu_pjrt_plugin", - default=False, - help_str=( - "Are we building the cuda/rocm pjrt plugin? jaxlib will not be built " - "when this flag is True." - ), + "--rocm_amdgpu_targets", + type=str, + default="gfx900,gfx906,gfx908,gfx90a,gfx1030", + help="A comma-separated list of ROCm amdgpu targets to support.", ) + + +def add_rocm_path_argument(parser: argparse.ArgumentParser): parser.add_argument( - "--gpu_plugin_cuda_version", - choices=["12"], - default="12", - help="Which CUDA major version the gpu plugin is for.") + "--rocm_path", + type=str, + default="", + help="Path to the ROCm toolkit.", + ) + + +def add_requirements_nightly_update_argument(parser: argparse.ArgumentParser): parser.add_argument( - "--gpu_plugin_rocm_version", - choices=["60"], - default="60", - help="Which ROCM major version the gpu plugin is for.") - add_boolean_argument( - parser, - "enable_rocm", - help_str="Should we build with ROCm enabled?") - add_boolean_argument( - parser, - "enable_nccl", - default=True, - help_str="Should we build with NCCL enabled? 
Has no effect for non-CUDA " - "builds.") - add_boolean_argument( - parser, - "remote_build", - default=False, - help_str="Should we build with RBE (Remote Build Environment)?") + "--nightly_update", + action="store_true", + help=""" + If true, updates requirements_lock.txt for a corresponding version of + Python and will consider dev, nightly and pre-release versions of + packages. + """, + ) + + +def add_global_arguments(parser: argparse.ArgumentParser): + """Adds all the global arguments that applies to all the CLI subcommands.""" parser.add_argument( - "--cuda_version", - default=None, - help="CUDA toolkit version, e.g., 12.3.2") + "--bazel_path", + type=str, + default="", + help=""" + Path to the Bazel binary to use. The default is to find bazel via the + PATH; if none is found, downloads a fresh copy of Bazel from GitHub. + """, + ) + parser.add_argument( - "--cudnn_version", - default=None, - help="CUDNN version, e.g., 8.9.7.29") - # Caution: if changing the default list of CUDA capabilities, you should also - # update the list in .bazelrc, which is used for wheel builds. + "--bazel_startup_options", + action="append", + default=[], + help=""" + Additional startup options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. --bazel_startup_options='--nobatch' + """, + ) + parser.add_argument( - "--cuda_compute_capabilities", - default=None, - help="A comma-separated list of CUDA compute capabilities to support.") + "--bazel_build_options", + action="append", + default=[], + help=""" + Additional build options to pass to Bazel, can be specified multiple + times to pass multiple options. + E.g. 
--bazel_build_options='--local_resources=HOST_CPUS' + """, + ) + parser.add_argument( - "--rocm_amdgpu_targets", - default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100", - help="A comma-separated list of ROCm amdgpu targets to support.") + "--dry_run", + action="store_true", + help="Prints the Bazel command that is going to be executed.", + ) + parser.add_argument( - "--rocm_path", - default=None, - help="Path to the ROCm toolkit.") + "--verbose", + action="store_true", + help="Produce verbose output for debugging.", + ) + + +def add_artifact_subcommand_global_arguments(parser: argparse.ArgumentParser): + """Adds all the global arguments that apply to the artifact subcommands.""" parser.add_argument( - "--bazel_startup_options", - action="append", default=[], - help="Additional startup options to pass to bazel.") + "--use_ci_bazelrc_flags", + action="store_true", + help=""" + When set, the CLI will assume the build is being run in CI or CI like + environment and will use the "rbe_/ci_" configs in the .bazelrc. These + configs apply release features and set a custom C++ Clang toolchain. + Only supported for jaxlib and CUDA builds. + """, + ) + parser.add_argument( - "--bazel_options", - action="append", default=[], - help="Additional options to pass to the main Bazel command to be " - "executed, e.g. `run`.") + "--editable", + action="store_true", + help="Create an 'editable' build instead of a wheel.", + ) + parser.add_argument( - "--output_path", - default=os.path.join(cwd, "dist"), - help="Directory to which the jaxlib wheel should be written") + "--disable_mkl_dnn", + action="store_true", + help=""" + Disables MKL-DNN. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc + flags enable MKL-DNN as default. + """, + ) + parser.add_argument( "--target_cpu", default=None, - help="CPU platform to target. Default is the same as the host machine. 
" - "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.") + help="CPU platform to target. Default is the same as the host machine. ", + ) + parser.add_argument( - "--editable", - action="store_true", - help="Create an 'editable' jaxlib build instead of a wheel.") + "--target_cpu_features", + choices=["release", "native", "default"], + default="release", + help=""" + What CPU features should we target? Release enables CPU features that + should be enabled for a release build, which on x86-64 architectures + enables AVX. Native enables -march=native, which generates code targeted + to use all features of the current machine. Default means don't opt-in + to any architectural features and use whatever the C compiler generates + by default. Ignored if --use_ci_bazelrc_flags is set, CI bazelrc flags + enable release CPU features as default. + """, + ) + parser.add_argument( - "--python_version", - default=None, - help="hermetic python version, e.g., 3.10") - add_boolean_argument( - parser, - "configure_only", - default=False, - help_str="If true, writes a .bazelrc file but does not build jaxlib.") - add_boolean_argument( - parser, - "requirements_update", - default=False, - help_str="If true, writes a .bazelrc and updates requirements_lock.txt " - "for a corresponding version of Python but does not build " - "jaxlib.") - add_boolean_argument( - parser, - "requirements_nightly_update", - default=False, - help_str="Same as update_requirements, but will consider dev, nightly " - "and pre-release versions of packages.") + "--clang_path", + type=str, + default="", + help=""" + Path to the Clang binary to use. Ignored if --use_ci_bazelrc_flags, CI + bazelrc flags set a custom Clang toolchain. + """, + ) + + parser.add_argument( + "--local_xla_path", + type=str, + default=os.environ.get("JAXCI_XLA_GIT_DIR", ""), + help=""" + Path to local XLA repository to use. If not set, Bazel uses the XLA at + the pinned version in workspace.bzl. 
+ """, + ) + + parser.add_argument( + "--output_path", + type=str, + default=os.path.join(os.getcwd(), "dist"), + help="Directory to which the JAX wheel packages should be written.", + ) + + parser.add_argument( + "--configure_only", + action="store_true", + help=""" + If true, writes the Bazel options to the .jax_configure.bazelrc file but + does not build the artifacts. Ignored if --use_ci_bazelrc_flags is set. + """, + ) + + +async def main(): + parser = argparse.ArgumentParser( + description=r""" + CLI for building one of the following packages from source: jaxlib, + jax-cuda-plugin, jax-cuda-pjrt, jax-rocm-plugin, jax-rocm-pjrt and for + updating the requirements_lock.txt files + """, + epilog=EPILOG, + ) + + # Create subparsers for requirements_update, jaxlib, and the CUDA/ROCm plugin and pjrt packages + subparsers = parser.add_subparsers(dest="command", required=True) + + # requirements_update subcommand + requirements_update_parser = subparsers.add_parser( + "requirements_update", help="Updates the requirements_lock.txt files" + ) + add_python_version_argument(requirements_update_parser) + add_requirements_nightly_update_argument(requirements_update_parser) + add_global_arguments(requirements_update_parser) + + # jaxlib subcommand + jaxlib_parser = subparsers.add_parser( + "jaxlib", help="Builds the jaxlib package." + ) + add_python_version_argument(jaxlib_parser) + add_artifact_subcommand_global_arguments(jaxlib_parser) + add_global_arguments(jaxlib_parser) + + # jax-cuda-plugin subcommand + cuda_plugin_parser = subparsers.add_parser( + "jax-cuda-plugin", help="Builds the jax-cuda-plugin package." 
+ ) + add_python_version_argument(cuda_plugin_parser) + add_build_cuda_with_clang_argument(cuda_plugin_parser) + add_cuda_version_argument(cuda_plugin_parser) + add_cudnn_version_argument(cuda_plugin_parser) + add_cuda_compute_capabilities_argument(cuda_plugin_parser) + add_disable_nccl_argument(cuda_plugin_parser) + add_artifact_subcommand_global_arguments(cuda_plugin_parser) + add_global_arguments(cuda_plugin_parser) + + # jax-cuda-pjrt subcommand + cuda_pjrt_parser = subparsers.add_parser( + "jax-cuda-pjrt", help="Builds the jax-cuda-pjrt package." + ) + add_build_cuda_with_clang_argument(cuda_pjrt_parser) + add_cuda_version_argument(cuda_pjrt_parser) + add_cudnn_version_argument(cuda_pjrt_parser) + add_cuda_compute_capabilities_argument(cuda_pjrt_parser) + add_disable_nccl_argument(cuda_pjrt_parser) + add_artifact_subcommand_global_arguments(cuda_pjrt_parser) + add_global_arguments(cuda_pjrt_parser) + + # jax-rocm-plugin subcommand + rocm_plugin_parser = subparsers.add_parser( + "jax-rocm-plugin", help="Builds the jax-rocm-plugin package." + ) + add_python_version_argument(rocm_plugin_parser) + add_rocm_version_argument(rocm_plugin_parser) + add_rocm_amdgpu_targets_argument(rocm_plugin_parser) + add_rocm_path_argument(rocm_plugin_parser) + add_disable_nccl_argument(rocm_plugin_parser) + add_artifact_subcommand_global_arguments(rocm_plugin_parser) + add_global_arguments(rocm_plugin_parser) + + # jax-rocm-pjrt subcommand + rocm_pjrt_parser = subparsers.add_parser( + "jax-rocm-pjrt", help="Builds the jax-rocm-pjrt package." + ) + add_rocm_version_argument(rocm_pjrt_parser) + add_rocm_amdgpu_targets_argument(rocm_pjrt_parser) + add_rocm_path_argument(rocm_pjrt_parser) + add_disable_nccl_argument(rocm_pjrt_parser) + add_artifact_subcommand_global_arguments(rocm_pjrt_parser) + add_global_arguments(rocm_pjrt_parser) + + arch = platform.machine() + # Switch to lower case to match the case for the "ci_"/"rbe_" configs in the + # .bazelrc. 
+ os_name = platform.system().lower() args = parser.parse_args() - logging.basicConfig() + logger.info("%s", BANNER) + if args.verbose: - logger.setLevel(logging.DEBUG) + logging.getLogger().setLevel(logging.DEBUG) + logger.info("Verbose logging enabled") + + logger.info( + "Building %s for %s %s...", + args.command, + os_name, + arch, + ) + + bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) - if args.enable_cuda and args.enable_rocm: - parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.") + logging.debug("Bazel path: %s", bazel_path) + logging.debug("Bazel version: %s", bazel_version) - print(BANNER) + executor = command.SubprocessExecutor() - output_path = os.path.abspath(args.output_path) - os.chdir(os.path.dirname(__file__ or args.prog) or '.') + # Start constructing the Bazel command + bazel_command = command.CommandBuilder(bazel_path) + + if args.bazel_startup_options: + logging.debug( + "Additional Bazel startup options: %s", args.bazel_startup_options + ) + for option in args.bazel_startup_options: + bazel_command.append(option) + + bazel_command.append("run") + + if hasattr(args, "python_version"): + logging.debug("Hermetic Python version: %s", args.python_version) + bazel_command.append( + f"--repo_env=HERMETIC_PYTHON_VERSION={args.python_version}" + ) + + if args.command == "requirements_update": + if args.bazel_build_options: + logging.debug( + "Using additional build options: %s", args.bazel_build_options + ) + for option in args.bazel_build_options: + bazel_command.append(option) + + if args.nightly_update: + logging.debug( + "--nightly_update is set. 
Bazel will run" + " //build:requirements_nightly.update" + ) + bazel_command.append("//build:requirements_nightly.update") + else: + bazel_command.append("//build:requirements.update") + + await executor.run(bazel_command.command, args.dry_run) + sys.exit(0) - host_cpu = platform.machine() wheel_cpus = { "darwin_arm64": "arm64", "darwin_x86_64": "x86_64", "ppc": "ppc64le", "aarch64": "aarch64", } - # TODO(phawkins): support other bazel cpu overrides. - wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None - else host_cpu) - - # Find a working Bazel. - bazel_path, bazel_version = utils.get_bazel_path(args.bazel_path) - print(f"Bazel binary path: {bazel_path}") - print(f"Bazel version: {bazel_version}") + target_cpu = ( + wheel_cpus[args.target_cpu] if args.target_cpu is not None else arch + ) - if args.python_version: - python_version = args.python_version + # Enable color in the Bazel output. + bazel_command.append("--color=yes") + + # If running in CI, we use the "ci_"/"rbe_" configs in the .bazelrc. + # These set a custom C++ Clang toolchain and the CUDA compiler to NVCC + # When not running in CI, we detect the path to Clang binary and pass it + # to Bazel to use as the C++ compiler. NVCC is used as the CUDA compiler + # unless the user explicitly sets --config=build_cuda_with_clang. 
+ if args.use_ci_bazelrc_flags and "rocm" not in args.command: + bazelrc_config = utils.get_ci_bazelrc_config(os_name, arch.lower(), args.command) + logging.debug("--use_ci_bazelrc_flags is set, using --config=%s from .bazelrc", bazelrc_config) + bazel_command.append(f"--config={bazelrc_config}") else: - python_bin_path = utils.get_python_bin_path(args.python_bin_path) - print(f"Python binary path: {python_bin_path}") - python_version = utils.get_python_version(python_bin_path) - print("Python version: {}".format(".".join(map(str, python_version)))) - utils.check_python_version(python_version) - python_version = ".".join(map(str, python_version)) - - print("Use clang: {}".format("yes" if args.use_clang else "no")) - clang_path = args.clang_path - clang_major_version = None - if args.use_clang: - if not clang_path: - clang_path = utils.get_clang_path_or_exit() - print(f"clang path: {clang_path}") - clang_major_version = utils.get_clang_major_version(clang_path) - - print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no")) - print(f"Target CPU: {wheel_cpu}") - print(f"Target CPU features: {args.target_cpu_features}") - - rocm_toolkit_path = args.rocm_path - print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no")) - if args.enable_cuda: - if args.cuda_compute_capabilities is not None: - print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}") + clang_path = args.clang_path or utils.get_clang_path_or_exit() + logging.debug("Using Clang as the compiler, clang path: %s", clang_path) + # Use double quotes around clang path to avoid path issues on Windows. 
+ bazel_command.append(f'--action_env=CLANG_COMPILER_PATH="{clang_path}"') + bazel_command.append(f'--repo_env=CC="{clang_path}"') + bazel_command.append(f'--repo_env=BAZEL_COMPILER="{clang_path}"') + bazel_command.append("--config=clang") + + if not args.disable_mkl_dnn: + logging.debug("Enabling MKL DNN") + bazel_command.append("--config=mkl_open_source_only") + + if "cuda" in args.command: + bazel_command.append("--config=cuda") + bazel_command.append( + f'--action_env=CLANG_CUDA_COMPILER_PATH="{clang_path}"' + ) + if args.build_cuda_with_clang: + logging.debug("Building CUDA with Clang") + bazel_command.append("--config=build_cuda_with_clang") + else: + logging.debug("Building CUDA with NVCC") + bazel_command.append("--config=build_cuda_with_nvcc") + + if args.target_cpu_features == "release": + logging.debug( + "Using release cpu features: --config=avx_%s", + "windows" if os_name == "windows" else "posix", + ) + if arch in ["x86_64", "AMD64"]: + bazel_command.append( + "--config=avx_windows" + if os_name == "windows" + else "--config=avx_posix" + ) + elif args.target_cpu_features == "native": + if os_name == "windows": + logger.warning( + "--target_cpu_features=native is not supported on Windows;" + " ignoring." 
+ ) + else: + logging.debug("Using native cpu features: --config=native_arch_posix") + bazel_command.append("--config=native_arch_posix") + else: + logging.debug("Using default cpu features") + + if args.target_cpu: + logging.debug("Target CPU: %s", args.target_cpu) + bazel_command.append(f"--cpu={args.target_cpu}") + + if hasattr(args, "disable_nccl") and args.disable_nccl: + logging.debug("Disabling NCCL") + bazel_command.append("--config=nonccl") + + if "cuda" in args.command: if args.cuda_version: - print(f"CUDA version: {args.cuda_version}") + logging.debug("Hermetic CUDA version: %s", args.cuda_version) + bazel_command.append( + f"--repo_env=HERMETIC_CUDA_VERSION={args.cuda_version}" + ) if args.cudnn_version: - print(f"CUDNN version: {args.cudnn_version}") - print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no")) - - print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no")) - if args.enable_rocm: - if rocm_toolkit_path: - print(f"ROCm toolkit path: {rocm_toolkit_path}") - print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}") - - write_bazelrc( - remote_build=args.remote_build, - cuda_version=args.cuda_version, - cudnn_version=args.cudnn_version, - rocm_toolkit_path=rocm_toolkit_path, - cpu=args.target_cpu, - cuda_compute_capabilities=args.cuda_compute_capabilities, - rocm_amdgpu_targets=args.rocm_amdgpu_targets, - target_cpu_features=args.target_cpu_features, - wheel_cpu=wheel_cpu, - enable_mkl_dnn=args.enable_mkl_dnn, - use_clang=args.use_clang, - clang_path=clang_path, - clang_major_version=clang_major_version, - python_version=python_version, - enable_cuda=args.enable_cuda, - enable_nccl=args.enable_nccl, - enable_rocm=args.enable_rocm, - use_cuda_nvcc=args.use_cuda_nvcc, - ) - - if args.requirements_update or args.requirements_nightly_update: - if args.requirements_update: - task = "//build:requirements.update" - else: # args.requirements_nightly_update - task = "//build:requirements_nightly.update" - update_command = 
([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", task, *args.bazel_options]) - print(" ".join(update_command)) - utils.shell(update_command) - return - - if args.configure_only: - return - - print("\nBuilding XLA and installing it in the jaxlib source tree...") - - command_base = ( - bazel_path, - *args.bazel_startup_options, - "run", - "--verbose_failures=true", - *args.bazel_options, - ) - - if args.build_gpu_plugin and args.editable: - output_path_jaxlib, output_path_jax_pjrt, output_path_jax_kernel = ( - _get_editable_output_paths(output_path) + logging.debug("Hermetic cuDNN version: %s", args.cudnn_version) + bazel_command.append( + f"--repo_env=HERMETIC_CUDNN_VERSION={args.cudnn_version}" + ) + if args.cuda_compute_capabilities: + logging.debug( + "Hermetic CUDA compute capabilities: %s", + args.cuda_compute_capabilities, + ) + bazel_command.append( + f"--repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES={args.cuda_compute_capabilities}" + ) + + if "rocm" in args.command: + bazel_command.append("--config=rocm") + + if args.rocm_path: + logging.debug("ROCm tookit path: %s", args.rocm_path) + bazel_command.append(f'--action_env=ROCM_PATH="{args.rocm_path}"') + if args.rocm_amdgpu_targets: + logging.debug("ROCm AMD GPU targets: %s", args.rocm_amdgpu_targets) + bazel_command.append( + f"--action_env=TF_ROCM_AMDGPU_TARGETS={args.rocm_amdgpu_targets}" + ) + + if args.local_xla_path: + logging.debug("Local XLA path: %s", args.local_xla_path) + bazel_command.append(f'--override_repository=xla="{args.local_xla_path}"') + + if args.bazel_build_options: + logging.debug( + "Additional Bazel build options: %s", args.bazel_build_options ) - else: - output_path_jaxlib = output_path - output_path_jax_pjrt = output_path - output_path_jax_kernel = output_path - - if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: - build_cpu_wheel_command = [ - *command_base, - "//jaxlib/tools:build_wheel", - "--", - 
f"--output_path={output_path_jaxlib}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.build_gpu_plugin: - build_cpu_wheel_command.append("--skip_gpu_kernels") - if args.editable: - build_cpu_wheel_command.append("--editable") - print(" ".join(build_cpu_wheel_command)) - utils.shell(build_cpu_wheel_command) - - if args.build_gpu_plugin or (args.build_gpu_kernel_plugin == "cuda") or \ - (args.build_gpu_kernel_plugin == "rocm"): - build_gpu_kernels_command = [ - *command_base, - "//jaxlib/tools:build_gpu_kernels_wheel", - "--", - f"--output_path={output_path_jax_kernel}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.enable_cuda: - build_gpu_kernels_command.append(f"--enable-cuda={args.enable_cuda}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") - elif args.enable_rocm: - build_gpu_kernels_command.append(f"--enable-rocm={args.enable_rocm}") - build_gpu_kernels_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") - else: - raise ValueError("Unsupported GPU plugin backend. 
Choose either 'cuda' or 'rocm'.") - if args.editable: - build_gpu_kernels_command.append("--editable") - print(" ".join(build_gpu_kernels_command)) - utils.shell(build_gpu_kernels_command) - - if args.build_gpu_plugin or args.build_gpu_pjrt_plugin: - build_pjrt_plugin_command = [ - *command_base, - "//jaxlib/tools:build_gpu_plugin_wheel", - "--", - f"--output_path={output_path_jax_pjrt}", - f"--jaxlib_git_hash={utils.get_githash()}", - f"--cpu={wheel_cpu}", - ] - if args.enable_cuda: - build_pjrt_plugin_command.append(f"--enable-cuda={args.enable_cuda}") - build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_cuda_version}") - elif args.enable_rocm: - build_pjrt_plugin_command.append(f"--enable-rocm={args.enable_rocm}") - build_pjrt_plugin_command.append(f"--platform_version={args.gpu_plugin_rocm_version}") + for option in args.bazel_build_options: + bazel_command.append(option) + + if not args.use_ci_bazelrc_flags: + with open(".jax_configure.bazelrc", "w") as f: + jax_configure_options = utils.get_jax_configure_bazel_options(bazel_command.parameters) + if not jax_configure_options: + logging.error("Error retrieving the Bazel options to be written to .jax_configure.bazelrc, exiting.") + sys.exit(1) + f.write(jax_configure_options) + logging.debug("Bazel options written to .jax_configure.bazelrc") + if args.configure_only: + logging.debug("--configure_only is set, exiting without running any Bazel commands.") + sys.exit(0) + + # Append the build target to the Bazel command. 
+ build_target = ARTIFACT_BUILD_TARGET_DICT[args.command] + bazel_command.append(build_target) + + bazel_command.append("--") + + output_path = args.output_path + logger.debug("Artifacts output directory: %s", output_path) + + if args.editable: + logger.debug("Building an editable build") + output_path = os.path.join(output_path, args.command) + bazel_command.append("--editable") + + bazel_command.append(f'--output_path="{output_path}"') + bazel_command.append(f"--cpu={target_cpu}") + + if "cuda" in args.command: + bazel_command.append("--enable-cuda=True") + if args.cuda_version: + cuda_major_version = args.cuda_version.split(".")[0] else: - raise ValueError("Unsupported GPU plugin backend. Choose either 'cuda' or 'rocm'.") - if args.editable: - build_pjrt_plugin_command.append("--editable") - print(" ".join(build_pjrt_plugin_command)) - utils.shell(build_pjrt_plugin_command) + cuda_major_version = utils.get_cuda_major_version() + bazel_command.append(f"--platform_version={cuda_major_version}") + + if "rocm" in args.command: + bazel_command.append("--enable-rocm=True") + bazel_command.append(f"--platform_version={args.rocm_version}") + + git_hash = utils.get_githash() + bazel_command.append(f"--jaxlib_git_hash={git_hash}") - utils.shell([bazel_path] + args.bazel_startup_options + ["shutdown"]) + await executor.run(bazel_command.command, args.dry_run) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/build/tools/command.py b/build/tools/command.py new file mode 100644 index 000000000000..e601259cc6f1 --- /dev/null +++ b/build/tools/command.py @@ -0,0 +1,104 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Helper script for the JAX build CLI for running subprocess commands. +import asyncio +import dataclasses +import datetime +import os +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + +class CommandBuilder: + def __init__(self, base_command: str): + self.command = base_command + self.parameters = [base_command] + + def append(self, parameter: str): + self.command += " {}".format(parameter) + self.parameters.append(parameter) + return self + +@dataclasses.dataclass +class CommandResult: + """ + Represents the result of executing a subprocess command. + """ + + command: str + return_code: int = 2 # Defaults to not successful + logs: str = "" + start_time: datetime.datetime = dataclasses.field( + default_factory=datetime.datetime.now + ) + end_time: Optional[datetime.datetime] = None + +class SubprocessExecutor: + """ + Manages execution of subprocess commands with reusable environment and logging. + """ + + def __init__(self, environment: Dict[str, str] = None): + """ + + Args: + environment: + """ + self.environment = environment or dict(os.environ) + + async def run(self, cmd: str, dry_run: bool = False) -> CommandResult: + """ + Executes a subprocess command. + + Args: + cmd: The command to execute. + dry_run: If True, prints the command instead of executing it. + + Returns: + A CommandResult instance. 
+ """ + result = CommandResult(command=cmd) + if dry_run: + logger.info("[DRY RUN] %s", cmd) + result.return_code = 0 # Dry run is a success + return result + + logger.info("[EXECUTING] %s", cmd) + + process = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=self.environment, + ) + + async def log_stream(stream, result: CommandResult): + while True: + line_bytes = await stream.readline() + if not line_bytes: + break + line = line_bytes.decode().rstrip() + result.logs += line + logger.info("%s", line) + + await asyncio.gather( + log_stream(process.stdout, result), log_stream(process.stderr, result) + ) + + result.return_code = await process.wait() + result.end_time = datetime.datetime.now() + logger.debug("Command finished with return code %s", result.return_code) + return result \ No newline at end of file diff --git a/build/tools/utils.py b/build/tools/utils.py index 4c8765371316..e12c045a4d1b 100644 --- a/build/tools/utils.py +++ b/build/tools/utils.py @@ -28,25 +28,6 @@ logger = logging.getLogger(__name__) -def is_windows(): - return sys.platform.startswith("win32") - -def shell(cmd): - try: - logger.info("shell(): %s", cmd) - output = subprocess.check_output(cmd) - except subprocess.CalledProcessError as e: - logger.info("subprocess raised: %s", e) - if e.output: - print(e.output) - raise - except Exception as e: - logger.info("subprocess raised: %s", e) - raise - return output.decode("UTF-8").strip() - - -# Bazel BAZEL_BASE_URI = "https://github.com/bazelbuild/bazel/releases/download/6.5.0/" BazelPackage = collections.namedtuple( "BazelPackage", ["base_uri", "file", "sha256"] @@ -180,7 +161,12 @@ def get_bazel_path(bazel_path_flag): def get_bazel_version(bazel_path): try: - version_output = shell([bazel_path, "--version"]) + version_output = subprocess.run( + [bazel_path, "--version"], + encoding="utf-8", + capture_output=True, + check=True, + ).stdout.strip() except 
(subprocess.CalledProcessError, OSError): return None match = re.search(r"bazel *([0-9\\.]+)", version_output) @@ -203,47 +189,70 @@ def get_clang_path_or_exit(): sys.exit(-1) -def get_clang_major_version(clang_path): - clang_version_proc = subprocess.run( - [clang_path, "-E", "-P", "-"], - input="__clang_major__", - check=True, - capture_output=True, - text=True, - ) - major_version = int(clang_version_proc.stdout) +def get_cuda_major_version(): + """Extract the CUDA major version from the .bazelrc""" + with open(".bazelrc", "r") as f: + for line in f: + match = re.search(r'HERMETIC_CUDA_VERSION="([^"]+)"', line) + if match: + cuda_version=match.group(1) + return cuda_version.split(".")[0] + return None - return major_version +def get_ci_bazelrc_config(os_name: str, arch: str, artifact: str): + """Returns the bazelrc config for the given architecture and OS. -# Python -def get_python_bin_path(python_bin_path_flag): - """Returns the path to the Python interpreter to use.""" - path = python_bin_path_flag or sys.executable - return path.replace(os.sep, "/") + Used in CI builds to retrieve either the "ci_"/"rbe_" configs from the + .bazelrc + """ + bazelrc_config = f"{os_name}_{arch}" -def get_python_version(python_bin_path): - version_output = shell([ - python_bin_path, - "-c", - ( - 'import sys; print("{}.{}".format(sys.version_info[0], ' - "sys.version_info[1]))" - ), - ]) - major, minor = map(int, version_output.split(".")) - return major, minor + # If building on Linux x86 or Windows, use the "rbe_" flags otherwise use + # the "ci_" (non-rbe) flags + if (os_name == "linux" and arch == "x86_64") or ( + os_name == "windows" and arch == "amd64" + ): + bazelrc_config = "rbe_" + bazelrc_config + else: + bazelrc_config = "ci_" + bazelrc_config + + # When building jax-cuda-plugin or jax-cuda-pjrt, append "_cuda" to the + # bazelrc config to use the CUDA specific configs. 
+ if "cuda" in artifact: + bazelrc_config = bazelrc_config + "_cuda" + + return bazelrc_config + + +def get_jax_configure_bazel_options(bazel_command: list[str]): + """Returns the bazel options to be written to .jax_configure.bazelrc.""" + # Get the index of the "run" parameter. Build options will come after "run" so + # we find the index of "run" and filter everything after it. + start = bazel_command.index("run") + jax_configure_bazel_options = "" + try: + for i in range(start + 1, len(bazel_command)): + bazel_flag = bazel_command[i] + # Replace all backslashes with double backslashes to avoid escaping issues + # when running on Windows. + if platform.system() == "Windows": + bazel_flag = bazel_flag.replace("\\", "\\\\") + jax_configure_bazel_options += f"build {bazel_flag}\n" + return jax_configure_bazel_options + except ValueError: + logging.error("Unable to find index for 'run' in the Bazel command") + return "" -def check_python_version(python_version): - if python_version < (3, 10): - print("ERROR: JAX requires Python 3.10 or newer, found ", python_version) - sys.exit(-1) def get_githash(): try: return subprocess.run( - ["git", "rev-parse", "HEAD"], encoding="utf-8", capture_output=True + ["git", "rev-parse", "HEAD"], + encoding="utf-8", + capture_output=True, + check=True, ).stdout.strip() except OSError: return "" diff --git a/docs/developer.md b/docs/developer.md index cbb60382b7f1..c15d810b4466 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -63,7 +63,7 @@ To build `jaxlib` from source, you must also install some prerequisites: To build `jaxlib` for CPU or TPU, you can run: ``` -python build/build.py +python build/build.py jaxlib --verbose pip install dist/*.whl # installs jaxlib (includes XLA) ``` @@ -71,7 +71,7 @@ To build a wheel for a version of Python different from your current system installation pass `--python_version` flag to the build command: ``` -python build/build.py --python_version=3.12 +python build/build.py jaxlib 
--python_version=3.12 --verbose ``` The rest of this document assumes that you are building for Python version @@ -81,13 +81,15 @@ version, simply append `--python_version=` flag every time you call installation regardless of whether the `--python_version` parameter is passed or not. -There are two ways to build `jaxlib` with CUDA support: (1) use -`python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda -support, or (2) use -`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` +To build `jaxlib` with CUDA support, run +``` +python build/build.py jaxlib +python build/build.py jax-cuda-plugin +python build/build.py jax-cuda-pjrt +``` to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and -jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and -clang, but it can be restricted to clang via the `--nouse_cuda_nvcc` flag. +jax-cuda-pjrt). By default all CUDA compilation steps are performed by NVCC and +clang, but they can be restricted to clang via the `--build_cuda_with_clang` flag. See `python build/build.py --help` for configuration options. Here `python` should be the name of your Python 3 interpreter; on some systems, you @@ -102,11 +104,16 @@ current directory. target dependencies. To download the specific versions of CUDA/CUDNN redistributions, you can use - the following command: + the `--cuda_version` and `--cudnn_version` flags: ```bash - python build/build.py --enable_cuda \ - --cuda_version=12.3.2 --cudnn_version=9.1.1 + python build/build.py jax-cuda-plugin --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 + ``` + or + ```bash + python build/build.py jax-cuda-pjrt --cuda_version=12.3.2 \ + --cudnn_version=9.1.1 + ``` Please note that these parameters are optional: by default Bazel will @@ -118,7 +125,7 @@ current directory.
the following command: ```bash - python build/build.py --enable_cuda \ + python build/build.py jax-cuda-plugin \ --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" @@ -141,7 +148,7 @@ ways to do this: line flag to `build.py` as follows: ``` - python build/build.py --bazel_options=--override_repository=xla=/path/to/xla + python build/build.py jaxlib --local_xla_path=/path/to/xla ``` - modify the `WORKSPACE` file in the root of the JAX source tree to point to @@ -203,12 +210,16 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \ The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`, and selecting the appropriate options. -To build jaxlib with ROCM support, you can run the following build command, +To build jaxlib with ROCM support, you can run the following build commands, suitably adjusted for your paths and ROCM version. ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py jaxlib +python3 ./build/build.py jax-rocm-plugin --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py jax-rocm-pjrt --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 ``` +to generate three wheels (jaxlib without ROCm, jax-rocm-plugin, and +jax-rocm-pjrt). AMD's fork of the XLA repository may include fixes not present in the upstream XLA repository.
If you experience problems with the upstream repository, you can @@ -221,7 +232,7 @@ git clone https://github.com/ROCm/xla.git and override the XLA repository with which JAX is built: ``` -python3 ./build/build.py --use_clang=true --clang_path=/usr/lib/llvm-18/bin/clang-18 --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --bazel_options=--override_repository=xla=/rel/xla/ --rocm_path=/opt/rocm-6.2.3 +python3 ./build/build.py jax-rocm-plugin --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 --local_xla_path=/rel/xla/ ``` For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`. @@ -284,7 +295,7 @@ direct dependencies list and then execute the following command (which will call [pip-compile](https://pypi.org/project/pip-tools/) under the hood): ``` -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` Alternatively, if you need more control, you may run the bazel command @@ -328,7 +339,7 @@ For example: ``` echo -e "\n$(realpath jaxlib-0.4.27.dev20240416-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` ### Specifying dependencies on nightly wheels @@ -338,7 +349,7 @@ dependencies we provide a special version of the dependency updater command as follows: ``` -python build/build.py --requirements_nightly_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 --nightly_update ``` Or, if you run `bazel` directly (the two commands are equivalent): @@ -469,10 +480,13 @@ or using pytest. ### Using Bazel -First, configure the JAX build by running: +First, configure the JAX build by using the `--configure_only` flag. 
Use the +`jaxlib` command for CPU tests and the `jax-cuda-plugin`/`jax-rocm-plugin` commands for GPU tests: ``` -python build/build.py --configure_only +python build/build.py jaxlib --configure_only +python build/build.py jax-cuda-plugin --configure_only +python build/build.py jax-rocm-plugin --configure_only ``` You may pass additional options to `build.py` to configure the build; see the @@ -494,14 +508,14 @@ make it available in the hermetic Python. To install a specific version of ``` echo -e "\njaxlib >= 0.4.26" >> build/requirements.in -python build/build.py --requirements_update +python build/build.py requirements_update ``` Alternatively, to install `jaxlib` from a local wheel (assuming Python 3.12): ``` echo -e "\n$(realpath jaxlib-0.4.26-cp312-cp312-manylinux2014_x86_64.whl)" >> build/requirements.in -python build/build.py --requirements_update --python_version=3.12 +python build/build.py requirements_update --python_version=3.12 ``` Once you have `jaxlib` installed hermetically, run: